| author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-05-03 09:12:24 +0000 |
|---|---|---|
| committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-05-03 09:12:24 +0000 |
| commit | 98d5537435b2951b36c45f1fda667fa27c165794 | |
| tree | d26b4dfa6cf91847100fe10a94a04dcc2ad36a86 /sqlglot/parser.py | |
| parent | Adding upstream version 11.5.2. | |
Adding upstream version 11.7.1. (tag: upstream/11.7.1)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
| -rw-r--r-- | sqlglot/parser.py | 382 |
|---|---|---|

1 file changed, 292 insertions(+), 90 deletions(-)
```diff
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index b3b899c..abb23ad 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -18,8 +18,13 @@ from sqlglot.trie import in_trie, new_trie
 
 logger = logging.getLogger("sqlglot")
 
+E = t.TypeVar("E", bound=exp.Expression)
+
 
 def parse_var_map(args: t.Sequence) -> exp.Expression:
+    if len(args) == 1 and args[0].is_star:
+        return exp.StarMap(this=args[0])
+
     keys = []
     values = []
     for i in range(0, len(args), 2):
@@ -108,6 +113,8 @@ class Parser(metaclass=_Parser):
         TokenType.CURRENT_USER: exp.CurrentUser,
     }
 
+    JOIN_HINTS: t.Set[str] = set()
+
     NESTED_TYPE_TOKENS = {
         TokenType.ARRAY,
         TokenType.MAP,
@@ -145,6 +152,7 @@ class Parser(metaclass=_Parser):
         TokenType.DATETIME,
         TokenType.DATE,
         TokenType.DECIMAL,
+        TokenType.BIGDECIMAL,
         TokenType.UUID,
         TokenType.GEOGRAPHY,
         TokenType.GEOMETRY,
@@ -221,8 +229,10 @@
         TokenType.FORMAT,
         TokenType.FULL,
         TokenType.IF,
+        TokenType.IS,
         TokenType.ISNULL,
         TokenType.INTERVAL,
+        TokenType.KEEP,
         TokenType.LAZY,
         TokenType.LEADING,
         TokenType.LEFT,
@@ -235,6 +245,7 @@
         TokenType.ONLY,
         TokenType.OPTIONS,
         TokenType.ORDINALITY,
+        TokenType.OVERWRITE,
         TokenType.PARTITION,
         TokenType.PERCENT,
         TokenType.PIVOT,
@@ -266,6 +277,8 @@
         *NO_PAREN_FUNCTIONS,
     }
 
+    INTERVAL_VARS = ID_VAR_TOKENS - {TokenType.END}
+
     TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
         TokenType.APPLY,
         TokenType.FULL,
@@ -276,6 +289,8 @@
         TokenType.WINDOW,
     }
 
+    COMMENT_TABLE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.IS}
+
     UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
 
     TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
@@ -400,7 +415,7 @@
     COLUMN_OPERATORS = {
         TokenType.DOT: None,
         TokenType.DCOLON: lambda self, this, to: self.expression(
-            exp.Cast,
+            exp.Cast if self.STRICT_CAST else exp.TryCast,
             this=this,
             to=to,
         ),
@@ -560,7 +575,7 @@
         ),
         "DEFINER": lambda self: self._parse_definer(),
         "DETERMINISTIC": lambda self: self.expression(
-            exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
+            exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
        ),
         "DISTKEY": lambda self: self._parse_distkey(),
         "DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
@@ -571,7 +586,7 @@
         "FREESPACE": lambda self: self._parse_freespace(),
         "GLOBAL": lambda self: self._parse_temporary(global_=True),
         "IMMUTABLE": lambda self: self.expression(
-            exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
+            exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
         ),
         "JOURNAL": lambda self: self._parse_journal(
             no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
@@ -600,20 +615,20 @@
         "PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
         "RETURNS": lambda self: self._parse_returns(),
         "ROW": lambda self: self._parse_row(),
+        "ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
         "SET": lambda self: self.expression(exp.SetProperty, multi=False),
         "SORTKEY": lambda self: self._parse_sortkey(),
         "STABLE": lambda self: self.expression(
-            exp.VolatilityProperty, this=exp.Literal.string("STABLE")
+            exp.StabilityProperty, this=exp.Literal.string("STABLE")
         ),
-        "STORED": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
+        "STORED": lambda self: self._parse_stored(),
         "TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
         "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
+        "TEMP": lambda self: self._parse_temporary(global_=False),
         "TEMPORARY": lambda self: self._parse_temporary(global_=False),
         "TRANSIENT": lambda self: self.expression(exp.TransientProperty),
         "USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
-        "VOLATILE": lambda self: self.expression(
-            exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
-        ),
+        "VOLATILE": lambda self: self._parse_volatile_property(),
         "WITH": lambda self: self._parse_with_property(),
     }
 
@@ -648,8 +663,11 @@
         "LIKE": lambda self: self._parse_create_like(),
         "NOT": lambda self: self._parse_not_constraint(),
         "NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True),
+        "ON": lambda self: self._match(TokenType.UPDATE)
+        and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function()),
         "PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()),
         "PRIMARY KEY": lambda self: self._parse_primary_key(),
+        "REFERENCES": lambda self: self._parse_references(match=False),
         "TITLE": lambda self: self.expression(
             exp.TitleColumnConstraint, this=self._parse_var_or_string()
         ),
@@ -668,9 +686,14 @@
     SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"}
 
     NO_PAREN_FUNCTION_PARSERS = {
+        TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
         TokenType.CASE: lambda self: self._parse_case(),
         TokenType.IF: lambda self: self._parse_if(),
-        TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
+        TokenType.NEXT_VALUE_FOR: lambda self: self.expression(
+            exp.NextValueFor,
+            this=self._parse_column(),
+            order=self._match(TokenType.OVER) and self._parse_wrapped(self._parse_order),
+        ),
     }
 
     FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
@@ -715,6 +738,8 @@
 
     SHOW_PARSERS: t.Dict[str, t.Callable] = {}
 
+    TYPE_LITERAL_PARSERS: t.Dict[exp.DataType.Type, t.Callable] = {}
+
     MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
 
     TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
@@ -731,6 +756,7 @@
     INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
 
     WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
+    WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
 
     ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
 
@@ -738,6 +764,9 @@
 
     CONVERT_TYPE_FIRST = False
 
+    QUOTED_PIVOT_COLUMNS: t.Optional[bool] = None
+    PREFIXED_PIVOT_COLUMNS = False
+
     LOG_BASE_FIRST = True
     LOG_DEFAULTS_TO_LN = False
 
@@ -895,8 +924,8 @@
             error level setting.
         """
         token = token or self._curr or self._prev or Token.string("")
-        start = self._find_token(token)
-        end = start + len(token.text)
+        start = token.start
+        end = token.end
         start_context = self.sql[max(start - self.error_message_context, 0) : start]
         highlight = self.sql[start:end]
         end_context = self.sql[end : end + self.error_message_context]
@@ -918,8 +947,8 @@
             self.errors.append(error)
 
     def expression(
-        self, exp_class: t.Type[exp.Expression], comments: t.Optional[t.List[str]] = None, **kwargs
-    ) -> exp.Expression:
+        self, exp_class: t.Type[E], comments: t.Optional[t.List[str]] = None, **kwargs
+    ) -> E:
         """
         Creates a new, validated Expression.
@@ -958,22 +987,7 @@
                 self.raise_error(error_message)
 
     def _find_sql(self, start: Token, end: Token) -> str:
-        return self.sql[self._find_token(start) : self._find_token(end) + len(end.text)]
-
-    def _find_token(self, token: Token) -> int:
-        line = 1
-        col = 1
-        index = 0
-
-        while line < token.line or col < token.col:
-            if Tokenizer.WHITE_SPACE.get(self.sql[index]) == TokenType.BREAK:
-                line += 1
-                col = 1
-            else:
-                col += 1
-            index += 1
-
-        return index
+        return self.sql[start.start : end.end]
 
     def _advance(self, times: int = 1) -> None:
         self._index += times
@@ -990,7 +1004,7 @@
         if index != self._index:
             self._advance(index - self._index)
 
-    def _parse_command(self) -> exp.Expression:
+    def _parse_command(self) -> exp.Command:
         return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())
 
     def _parse_comment(self, allow_exists: bool = True) -> exp.Expression:
@@ -1007,7 +1021,7 @@
         if kind.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
             this = self._parse_user_defined_function(kind=kind.token_type)
         elif kind.token_type == TokenType.TABLE:
-            this = self._parse_table()
+            this = self._parse_table(alias_tokens=self.COMMENT_TABLE_ALIAS_TOKENS)
         elif kind.token_type == TokenType.COLUMN:
             this = self._parse_column()
         else:
@@ -1035,16 +1049,13 @@
             self._parse_query_modifiers(expression)
         return expression
 
-    def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]:
+    def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]:
         start = self._prev
         temporary = self._match(TokenType.TEMPORARY)
         materialized = self._match(TokenType.MATERIALIZED)
         kind = self._match_set(self.CREATABLES) and self._prev.text
 
         if not kind:
-            if default_kind:
-                kind = default_kind
-            else:
-                return self._parse_as_command(start)
+            return self._parse_as_command(start)
 
         return self.expression(
             exp.Drop,
@@ -1055,6 +1066,7 @@
             materialized=materialized,
             cascade=self._match(TokenType.CASCADE),
             constraints=self._match_text_seq("CONSTRAINTS"),
+            purge=self._match_text_seq("PURGE"),
         )
 
     def _parse_exists(self, not_: bool = False) -> t.Optional[bool]:
@@ -1070,7 +1082,6 @@
             TokenType.OR, TokenType.REPLACE
         )
         unique = self._match(TokenType.UNIQUE)
-        volatile = self._match(TokenType.VOLATILE)
 
         if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
             self._match(TokenType.TABLE)
@@ -1179,7 +1190,6 @@
             kind=create_token.text,
             replace=replace,
             unique=unique,
-            volatile=volatile,
             expression=expression,
             exists=exists,
             properties=properties,
@@ -1225,6 +1235,21 @@
 
         return None
 
+    def _parse_stored(self) -> exp.Expression:
+        self._match(TokenType.ALIAS)
+
+        input_format = self._parse_string() if self._match_text_seq("INPUTFORMAT") else None
+        output_format = self._parse_string() if self._match_text_seq("OUTPUTFORMAT") else None
+
+        return self.expression(
+            exp.FileFormatProperty,
+            this=self.expression(
+                exp.InputOutputFormat, input_format=input_format, output_format=output_format
+            )
+            if input_format or output_format
+            else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
+        )
+
     def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression:
         self._match(TokenType.EQ)
         self._match(TokenType.ALIAS)
@@ -1258,6 +1283,21 @@
             exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION")
         )
 
+    def _parse_volatile_property(self) -> exp.Expression:
+        if self._index >= 2:
+            pre_volatile_token = self._tokens[self._index - 2]
+        else:
+            pre_volatile_token = None
+
+        if pre_volatile_token and pre_volatile_token.token_type in (
+            TokenType.CREATE,
+            TokenType.REPLACE,
+            TokenType.UNIQUE,
+        ):
+            return exp.VolatileProperty()
+
+        return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE"))
+
     def _parse_with_property(
         self,
     ) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]:
@@ -1574,11 +1614,46 @@
             exists=self._parse_exists(),
             partition=self._parse_partition(),
             expression=self._parse_ddl_select(),
+            conflict=self._parse_on_conflict(),
             returning=self._parse_returning(),
             overwrite=overwrite,
             alternative=alternative,
         )
 
+    def _parse_on_conflict(self) -> t.Optional[exp.Expression]:
+        conflict = self._match_text_seq("ON", "CONFLICT")
+        duplicate = self._match_text_seq("ON", "DUPLICATE", "KEY")
+
+        if not (conflict or duplicate):
+            return None
+
+        nothing = None
+        expressions = None
+        key = None
+        constraint = None
+
+        if conflict:
+            if self._match_text_seq("ON", "CONSTRAINT"):
+                constraint = self._parse_id_var()
+            else:
+                key = self._parse_csv(self._parse_value)
+
+        self._match_text_seq("DO")
+        if self._match_text_seq("NOTHING"):
+            nothing = True
+        else:
+            self._match(TokenType.UPDATE)
+            expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality)
+
+        return self.expression(
+            exp.OnConflict,
+            duplicate=duplicate,
+            expressions=expressions,
+            nothing=nothing,
+            key=key,
+            constraint=constraint,
+        )
+
     def _parse_returning(self) -> t.Optional[exp.Expression]:
         if not self._match(TokenType.RETURNING):
             return None
@@ -1639,7 +1714,7 @@
 
         return self.expression(
             exp.Delete,
-            this=self._parse_table(schema=True),
+            this=self._parse_table(),
             using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()),
             where=self._parse_where(),
             returning=self._parse_returning(),
@@ -1792,6 +1867,7 @@
         if not skip_with_token and not self._match(TokenType.WITH):
             return None
 
+        comments = self._prev_comments
         recursive = self._match(TokenType.RECURSIVE)
 
         expressions = []
@@ -1803,7 +1879,9 @@
             else:
                 self._match(TokenType.WITH)
 
-        return self.expression(exp.With, expressions=expressions, recursive=recursive)
+        return self.expression(
+            exp.With, comments=comments, expressions=expressions, recursive=recursive
+        )
 
     def _parse_cte(self) -> exp.Expression:
         alias = self._parse_table_alias()
@@ -1856,15 +1934,20 @@
             table = isinstance(this, exp.Table)
 
             while True:
-                lateral = self._parse_lateral()
                 join = self._parse_join()
-                comma = None if table else self._match(TokenType.COMMA)
-                if lateral:
-                    this.append("laterals", lateral)
                 if join:
                     this.append("joins", join)
+
+                lateral = None
+                if not join:
+                    lateral = self._parse_lateral()
+                    if lateral:
+                        this.append("laterals", lateral)
+
+                comma = None if table else self._match(TokenType.COMMA)
                 if comma:
                     this.args["from"].append("expressions", self._parse_table())
+
                 if not (lateral or join or comma):
                     break
 
@@ -1906,14 +1989,13 @@
     def _parse_match_recognize(self) -> t.Optional[exp.Expression]:
         if not self._match(TokenType.MATCH_RECOGNIZE):
             return None
+
         self._match_l_paren()
 
         partition = self._parse_partition_by()
         order = self._parse_order()
         measures = (
-            self._parse_alias(self._parse_conjunction())
-            if self._match_text_seq("MEASURES")
-            else None
+            self._parse_csv(self._parse_expression) if self._match_text_seq("MEASURES") else None
         )
 
         if self._match_text_seq("ONE", "ROW", "PER", "MATCH"):
@@ -1967,8 +2049,17 @@
             pattern = None
 
         define = (
-            self._parse_alias(self._parse_conjunction()) if self._match_text_seq("DEFINE") else None
+            self._parse_csv(
+                lambda: self.expression(
+                    exp.Alias,
+                    alias=self._parse_id_var(any_token=True),
+                    this=self._match(TokenType.ALIAS) and self._parse_conjunction(),
+                )
+            )
+            if self._match_text_seq("DEFINE")
+            else None
         )
+
         self._match_r_paren()
 
         return self.expression(
@@ -1980,6 +2071,7 @@
             after=after,
             pattern=pattern,
             define=define,
+            alias=self._parse_table_alias(),
         )
 
     def _parse_lateral(self) -> t.Optional[exp.Expression]:
@@ -2022,9 +2114,6 @@
             alias=table_alias,
         )
 
-        if outer_apply or cross_apply:
-            return self.expression(exp.Join, this=expression, side=None if cross_apply else "LEFT")
-
         return expression
 
     def _parse_join_side_and_kind(
@@ -2037,11 +2126,26 @@
         )
 
     def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
+        index = self._index
         natural, side, kind = self._parse_join_side_and_kind()
+        hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
+        join = self._match(TokenType.JOIN)
 
-        if not skip_join_token and not self._match(TokenType.JOIN):
+        if not skip_join_token and not join:
+            self._retreat(index)
+            kind = None
+            natural = None
+            side = None
+
+        outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY, False)
+        cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY, False)
+
+        if not skip_join_token and not join and not outer_apply and not cross_apply:
             return None
 
+        if outer_apply:
+            side = Token(TokenType.LEFT, "LEFT")
+
         kwargs: t.Dict[
             str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]]
         ] = {"this": self._parse_table()}
@@ -2052,6 +2156,8 @@
             kwargs["side"] = side.text
         if kind:
             kwargs["kind"] = kind.text
+        if hint:
+            kwargs["hint"] = hint
 
         if self._match(TokenType.ON):
             kwargs["on"] = self._parse_conjunction()
@@ -2179,7 +2285,7 @@
             return None
 
         expressions = self._parse_wrapped_csv(self._parse_column)
-        ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY))
+        ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
         alias = self._parse_table_alias()
 
         if alias and self.unnest_column_only:
@@ -2191,7 +2297,7 @@
         offset = None
         if self._match_pair(TokenType.WITH, TokenType.OFFSET):
             self._match(TokenType.ALIAS)
-            offset = self._parse_conjunction()
+            offset = self._parse_id_var() or exp.Identifier(this="offset")
 
         return self.expression(
             exp.Unnest,
@@ -2294,6 +2400,9 @@
         else:
             expressions = self._parse_csv(lambda: self._parse_alias(self._parse_function()))
 
+        if not expressions:
+            self.raise_error("Failed to parse PIVOT's aggregation list")
+
         if not self._match(TokenType.FOR):
             self.raise_error("Expecting FOR")
 
@@ -2311,8 +2420,26 @@
         if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False):
             pivot.set("alias", self._parse_table_alias())
 
+        if not unpivot:
+            names = self._pivot_column_names(t.cast(t.List[exp.Expression], expressions))
+
+            columns: t.List[exp.Expression] = []
+            for col in pivot.args["field"].expressions:
+                for name in names:
+                    if self.PREFIXED_PIVOT_COLUMNS:
+                        name = f"{name}_{col.alias_or_name}" if name else col.alias_or_name
+                    else:
+                        name = f"{col.alias_or_name}_{name}" if name else col.alias_or_name
+
+                    columns.append(exp.to_identifier(name, quoted=self.QUOTED_PIVOT_COLUMNS))
+
+            pivot.set("columns", columns)
+
         return pivot
 
+    def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
+        return [agg.alias for agg in pivot_columns]
+
     def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
         if not skip_where_token and not self._match(TokenType.WHERE):
             return None
@@ -2433,10 +2560,25 @@
         if self._match(TokenType.FETCH):
             direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
             direction = self._prev.text if direction else "FIRST"
+
             count = self._parse_number()
+            percent = self._match(TokenType.PERCENT)
+
             self._match_set((TokenType.ROW, TokenType.ROWS))
-            self._match(TokenType.ONLY)
-            return self.expression(exp.Fetch, direction=direction, count=count)
+
+            only = self._match(TokenType.ONLY)
+            with_ties = self._match_text_seq("WITH", "TIES")
+
+            if only and with_ties:
+                self.raise_error("Cannot specify both ONLY and WITH TIES in FETCH clause")
+
+            return self.expression(
+                exp.Fetch,
+                direction=direction,
+                count=count,
+                percent=percent,
+                with_ties=with_ties,
+            )
 
         return this
 
@@ -2493,7 +2635,11 @@
         negate = self._match(TokenType.NOT)
 
         if self._match_set(self.RANGE_PARSERS):
-            this = self.RANGE_PARSERS[self._prev.token_type](self, this)
+            expression = self.RANGE_PARSERS[self._prev.token_type](self, this)
+
+            if not expression:
+                return this
+
+            this = expression
         elif self._match(TokenType.ISNULL):
             this = self.expression(exp.Is, this=this, expression=exp.Null())
 
@@ -2511,17 +2657,19 @@
 
         return this
 
-    def _parse_is(self, this: t.Optional[exp.Expression]) -> exp.Expression:
+    def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+        index = self._index - 1
         negate = self._match(TokenType.NOT)
+
         if self._match(TokenType.DISTINCT_FROM):
             klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
             return self.expression(klass, this=this, expression=self._parse_expression())
 
-        this = self.expression(
-            exp.Is,
-            this=this,
-            expression=self._parse_null() or self._parse_boolean(),
-        )
+        expression = self._parse_null() or self._parse_boolean()
+        if not expression:
+            self._retreat(index)
+            return None
+
+        this = self.expression(exp.Is, this=this, expression=expression)
         return self.expression(exp.Not, this=this) if negate else this
 
     def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression:
@@ -2553,6 +2701,27 @@
             return this
         return self.expression(exp.Escape, this=this, expression=self._parse_string())
 
+    def _parse_interval(self) -> t.Optional[exp.Expression]:
+        if not self._match(TokenType.INTERVAL):
+            return None
+
+        this = self._parse_primary() or self._parse_term()
+        unit = self._parse_function() or self._parse_var()
+
+        # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
+        # each INTERVAL expression into this canonical form so it's easy to transpile
+        if this and isinstance(this, exp.Literal):
+            if this.is_number:
+                this = exp.Literal.string(this.name)
+
+            # Try to not clutter Snowflake's multi-part intervals like INTERVAL '1 day, 1 year'
+            parts = this.name.split()
+            if not unit and len(parts) <= 2:
+                this = exp.Literal.string(seq_get(parts, 0))
+                unit = self.expression(exp.Var, this=seq_get(parts, 1))
+
+        return self.expression(exp.Interval, this=this, unit=unit)
+
     def _parse_bitwise(self) -> t.Optional[exp.Expression]:
         this = self._parse_term()
 
@@ -2588,20 +2757,24 @@
         return self._parse_at_time_zone(self._parse_type())
 
     def _parse_type(self) -> t.Optional[exp.Expression]:
-        if self._match(TokenType.INTERVAL):
-            return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_field())
+        interval = self._parse_interval()
+        if interval:
+            return interval
 
         index = self._index
-        type_token = self._parse_types(check_func=True)
+        data_type = self._parse_types(check_func=True)
         this = self._parse_column()
 
-        if type_token:
+        if data_type:
             if isinstance(this, exp.Literal):
-                return self.expression(exp.Cast, this=this, to=type_token)
-            if not type_token.args.get("expressions"):
+                parser = self.TYPE_LITERAL_PARSERS.get(data_type.this)
+                if parser:
+                    return parser(self, this, data_type)
+                return self.expression(exp.Cast, this=this, to=data_type)
+            if not data_type.args.get("expressions"):
                 self._retreat(index)
                 return self._parse_column()
-            return type_token
+            return data_type
 
         return this
 
@@ -2631,11 +2804,10 @@
             else:
                 expressions = self._parse_csv(self._parse_conjunction)
 
-            if not expressions:
+            if not expressions or not self._match(TokenType.R_PAREN):
                 self._retreat(index)
                 return None
 
-            self._match_r_paren()
             maybe_func = True
 
         if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
@@ -2720,15 +2892,14 @@
         )
 
     def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
-        if self._curr and self._curr.token_type in self.TYPE_TOKENS:
-            return self._parse_types()
-
+        index = self._index
         this = self._parse_id_var()
         self._match(TokenType.COLON)
         data_type = self._parse_types()
 
         if not data_type:
-            return None
+            self._retreat(index)
+            return self._parse_types()
 
         return self.expression(exp.StructKwarg, this=this, expression=data_type)
 
@@ -2825,6 +2996,7 @@
                 this = self.expression(exp.Paren, this=self._parse_set_operations(this))
 
             self._match_r_paren()
+            comments.extend(self._prev_comments)
 
             if this and comments:
                 this.comments = comments
@@ -2833,8 +3005,16 @@
 
         return None
 
-    def _parse_field(self, any_token: bool = False) -> t.Optional[exp.Expression]:
-        return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token)
+    def _parse_field(
+        self,
+        any_token: bool = False,
+        tokens: t.Optional[t.Collection[TokenType]] = None,
+    ) -> t.Optional[exp.Expression]:
+        return (
+            self._parse_primary()
+            or self._parse_function()
+            or self._parse_id_var(any_token=any_token, tokens=tokens)
+        )
 
     def _parse_function(
         self, functions: t.Optional[t.Dict[str, t.Callable]] = None
@@ -3079,12 +3259,10 @@
         return None
 
     def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
-        this = self._parse_references()
-        if this:
-            return this
-
         if self._match(TokenType.CONSTRAINT):
             this = self._parse_id_var()
+        else:
+            this = None
 
         if self._match_texts(self.CONSTRAINT_PARSERS):
             return self.expression(
@@ -3164,8 +3342,8 @@
 
         return options
 
-    def _parse_references(self) -> t.Optional[exp.Expression]:
-        if not self._match(TokenType.REFERENCES):
+    def _parse_references(self, match=True) -> t.Optional[exp.Expression]:
+        if match and not self._match(TokenType.REFERENCES):
             return None
 
         expressions = None
@@ -3234,7 +3412,7 @@
         elif not this or this.name.upper() == "ARRAY":
             this = self.expression(exp.Array, expressions=expressions)
         else:
-            expressions = apply_index_offset(expressions, -self.index_offset)
+            expressions = apply_index_offset(this, expressions, -self.index_offset)
             this = self.expression(exp.Bracket, this=this, expressions=expressions)
 
         if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
@@ -3279,7 +3457,13 @@
             self.validate_expression(this, args)
             self._match_r_paren()
         else:
+            index = self._index - 1
             condition = self._parse_conjunction()
+
+            if not condition:
+                self._retreat(index)
+                return None
+
             self._match(TokenType.THEN)
             true = self._parse_conjunction()
             false = self._parse_conjunction() if self._match(TokenType.ELSE) else None
@@ -3591,14 +3775,24 @@
         # bigquery select from window x AS (partition by ...)
         if alias:
+            over = None
             self._match(TokenType.ALIAS)
-        elif not self._match(TokenType.OVER):
+        elif not self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS):
            return this
+        else:
+            over = self._prev.text.upper()
 
         if not self._match(TokenType.L_PAREN):
-            return self.expression(exp.Window, this=this, alias=self._parse_id_var(False))
+            return self.expression(
+                exp.Window, this=this, alias=self._parse_id_var(False), over=over
+            )
 
         window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS)
+
+        first = self._match(TokenType.FIRST)
+        if self._match_text_seq("LAST"):
+            first = False
+
         partition = self._parse_partition_by()
         order = self._parse_order()
         kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text
@@ -3629,6 +3823,8 @@
             order=order,
             spec=spec,
             alias=window_alias,
+            over=over,
+            first=first,
         )
 
     def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:
@@ -3886,7 +4082,10 @@
         return expression
 
     def _parse_drop_column(self) -> t.Optional[exp.Expression]:
-        return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")
+        drop = self._match(TokenType.DROP) and self._parse_drop()
+        if drop and not isinstance(drop, exp.Command):
+            drop.set("kind", drop.args.get("kind", "COLUMN"))
+        return drop
 
     # https://docs.aws.amazon.com/athena/latest/ug/alter-table-drop-partition.html
     def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.Expression:
@@ -4010,7 +4209,7 @@
             if self._match(TokenType.INSERT):
                 _this = self._parse_star()
                 if _this:
-                    then = self.expression(exp.Insert, this=_this)
+                    then: t.Optional[exp.Expression] = self.expression(exp.Insert, this=_this)
                 else:
                     then = self.expression(
                         exp.Insert,
@@ -4239,5 +4438,8 @@
                     break
                 parent = parent.parent
             else:
-                column.replace(dot_or_id)
+                if column is node:
+                    node = dot_or_id
+                else:
+                    column.replace(dot_or_id)
         return node
```
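A few usage sketches of the behavior these hunks introduce follow, using sqlglot's public API. They are illustrative, not part of the commit; expected outputs are noted in comments and assume this 11.7.x parser. First, the new `_parse_interval` canonicalizes numeric interval quantities into string literals, per the inline comment in the hunk, so both spellings below should transpile to the same canonical form:

```python
import sqlglot

# Both spellings should normalize to INTERVAL '5' day, since a numeric
# quantity is rewritten into a string literal during parsing.
print(sqlglot.transpile("SELECT INTERVAL 5 day")[0])    # SELECT INTERVAL '5' day
print(sqlglot.transpile("SELECT INTERVAL '5' day")[0])  # SELECT INTERVAL '5' day
```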
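The new `_parse_on_conflict` attaches an `exp.OnConflict` node to the insert's `conflict` arg, covering both PostgreSQL-style `ON CONFLICT` and MySQL-style `ON DUPLICATE KEY UPDATE` (the latter sets `duplicate=True`). A minimal sketch, assuming the default dialect accepts both forms:

```python
import sqlglot
from sqlglot import exp

# ON CONFLICT ... DO NOTHING lands in the insert's "conflict" arg.
ast = sqlglot.parse_one("INSERT INTO t VALUES (1) ON CONFLICT DO NOTHING")
conflict = ast.args.get("conflict")
print(isinstance(conflict, exp.OnConflict), conflict.args.get("nothing"))  # True True
```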
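`PREFIXED_PIVOT_COLUMNS`, `QUOTED_PIVOT_COLUMNS`, and the `_pivot_column_names` hook let dialect parsers control how the synthesized PIVOT output columns are named. A hypothetical subclass (the class and its naming policy are illustrative, not an existing sqlglot dialect):

```python
import typing as t

from sqlglot import exp, parser

class HypotheticalPivotParser(parser.Parser):
    # Name pivot output columns "<aggregation name>_<IN-clause value>" and
    # quote them, instead of the default "<IN-clause value>_<agg alias>".
    PREFIXED_PIVOT_COLUMNS = True
    QUOTED_PIVOT_COLUMNS = True

    def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
        # Fall back to the expression's own name when no alias was written,
        # rather than the base class's alias-only behavior.
        return [agg.alias_or_name for agg in pivot_columns]
```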
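Finally, `FETCH` now records `PERCENT` and `WITH TIES` (and rejects the contradictory `ONLY ... WITH TIES` combination). A quick check of the parsed node, with the expected values in the comment:

```python
import sqlglot
from sqlglot import exp

# Both new flags should be captured on the exp.Fetch node.
ast = sqlglot.parse_one("SELECT * FROM t ORDER BY a FETCH FIRST 10 PERCENT ROWS WITH TIES")
fetch = ast.find(exp.Fetch)
print(fetch.args.get("percent"), fetch.args.get("with_ties"))  # True True
```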