diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-01-23 05:06:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-01-23 05:06:14 +0000 |
commit | 38e6461a8afbd7cb83709ddb998f03d40ba87755 (patch) | |
tree | 64b68a893a3b946111b9cab69503f83ca233c335 /sqlglot/parser.py | |
parent | Releasing debian version 20.4.0-1. (diff) | |
download | sqlglot-38e6461a8afbd7cb83709ddb998f03d40ba87755.tar.xz sqlglot-38e6461a8afbd7cb83709ddb998f03d40ba87755.zip |
Merging upstream version 20.9.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- | sqlglot/parser.py | 315 |
1 files changed, 229 insertions, 86 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 311c43d..790ee0d 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -12,6 +12,8 @@ from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import TrieResult, in_trie, new_trie if t.TYPE_CHECKING: + from typing_extensions import Literal + from sqlglot._typing import E from sqlglot.dialects.dialect import Dialect, DialectType @@ -193,6 +195,7 @@ class Parser(metaclass=_Parser): TokenType.DATETIME, TokenType.DATETIME64, TokenType.DATE, + TokenType.DATE32, TokenType.INT4RANGE, TokenType.INT4MULTIRANGE, TokenType.INT8RANGE, @@ -232,6 +235,8 @@ class Parser(metaclass=_Parser): TokenType.INET, TokenType.IPADDRESS, TokenType.IPPREFIX, + TokenType.IPV4, + TokenType.IPV6, TokenType.UNKNOWN, TokenType.NULL, *ENUM_TYPE_TOKENS, @@ -669,6 +674,7 @@ class Parser(metaclass=_Parser): PROPERTY_PARSERS: t.Dict[str, t.Callable] = { "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty), + "AUTO": lambda self: self._parse_auto_property(), "AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty), "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), "CHARSET": lambda self, **kwargs: self._parse_character_set(**kwargs), @@ -680,6 +686,7 @@ class Parser(metaclass=_Parser): exp.CollateProperty, **kwargs ), "COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), + "CONTAINS": lambda self: self._parse_contains_property(), "COPY": lambda self: self._parse_copy_property(), "DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs), "DEFINER": lambda self: self._parse_definer(), @@ -710,6 +717,7 @@ class Parser(metaclass=_Parser): "LOG": lambda self, **kwargs: self._parse_log(**kwargs), "MATERIALIZED": lambda self: self.expression(exp.MaterializedProperty), "MERGEBLOCKRATIO": lambda self, **kwargs: self._parse_mergeblockratio(**kwargs), + "MODIFIES": lambda self: self._parse_modifies_property(), "MULTISET": lambda self: self.expression(exp.SetProperty, multi=True), "NO": lambda self: self._parse_no_property(), "ON": lambda self: self._parse_on_property(), @@ -721,6 +729,7 @@ class Parser(metaclass=_Parser): "PARTITIONED_BY": lambda self: self._parse_partitioned_by(), "PRIMARY KEY": lambda self: self._parse_primary_key(in_props=True), "RANGE": lambda self: self._parse_dict_range(this="RANGE"), + "READS": lambda self: self._parse_reads_property(), "REMOTE": lambda self: self._parse_remote_with_connection(), "RETURNS": lambda self: self._parse_returns(), "ROW": lambda self: self._parse_row(), @@ -841,6 +850,7 @@ class Parser(metaclass=_Parser): "DECODE": lambda self: self._parse_decode(), "EXTRACT": lambda self: self._parse_extract(), "JSON_OBJECT": lambda self: self._parse_json_object(), + "JSON_OBJECTAGG": lambda self: self._parse_json_object(agg=True), "JSON_TABLE": lambda self: self._parse_json_table(), "MATCH": lambda self: self._parse_match_against(), "OPENJSON": lambda self: self._parse_open_json(), @@ -925,6 +935,8 @@ class Parser(metaclass=_Parser): WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER} WINDOW_SIDES = {"FOLLOWING", "PRECEDING"} + JSON_KEY_VALUE_SEPARATOR_TOKENS = {TokenType.COLON, TokenType.COMMA, TokenType.IS} + FETCH_TOKENS = ID_VAR_TOKENS - {TokenType.ROW, TokenType.ROWS, TokenType.PERCENT} ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY} @@ -954,6 +966,9 @@ class Parser(metaclass=_Parser): # Whether the TRIM function expects the characters to trim as its first argument TRIM_PATTERN_FIRST = False + # Whether or not string aliases are supported `SELECT COUNT(*) 'count'` + STRING_ALIASES = False + # Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand) MODIFIERS_ATTACHED_TO_UNION = True UNION_MODIFIERS = {"order", "limit", "offset"} @@ -1193,7 +1208,9 @@ class Parser(metaclass=_Parser): self._advance(index - self._index) def _parse_command(self) -> exp.Command: - return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string()) + return self.expression( + exp.Command, this=self._prev.text.upper(), expression=self._parse_string() + ) def _parse_comment(self, allow_exists: bool = True) -> exp.Expression: start = self._prev @@ -1353,26 +1370,27 @@ class Parser(metaclass=_Parser): # exp.Properties.Location.POST_SCHEMA ("schema" here is the UDF's type signature) extend_props(self._parse_properties()) - self._match(TokenType.ALIAS) - - if self._match(TokenType.COMMAND): - expression = self._parse_as_command(self._prev) - else: - begin = self._match(TokenType.BEGIN) - return_ = self._match_text_seq("RETURN") + expression = self._match(TokenType.ALIAS) and self._parse_heredoc() - if self._match(TokenType.STRING, advance=False): - # Takes care of BigQuery's JavaScript UDF definitions that end in an OPTIONS property - # # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement - expression = self._parse_string() - extend_props(self._parse_properties()) + if not expression: + if self._match(TokenType.COMMAND): + expression = self._parse_as_command(self._prev) else: - expression = self._parse_statement() + begin = self._match(TokenType.BEGIN) + return_ = self._match_text_seq("RETURN") + + if self._match(TokenType.STRING, advance=False): + # Takes care of BigQuery's JavaScript UDF definitions that end in an OPTIONS property + # # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement + expression = self._parse_string() + extend_props(self._parse_properties()) + else: + expression = self._parse_statement() - end = self._match_text_seq("END") + end = self._match_text_seq("END") - if return_: - expression = self.expression(exp.Return, this=expression) + if return_: + expression = self.expression(exp.Return, this=expression) elif create_token.token_type == TokenType.INDEX: this = self._parse_index(index=self._parse_id_var()) elif create_token.token_type in self.DB_CREATABLES: @@ -1426,7 +1444,7 @@ class Parser(metaclass=_Parser): exp.Create, comments=comments, this=this, - kind=create_token.text, + kind=create_token.text.upper(), replace=replace, unique=unique, expression=expression, @@ -1849,9 +1867,21 @@ class Parser(metaclass=_Parser): return self.expression(exp.WithDataProperty, no=no, statistics=statistics) - def _parse_no_property(self) -> t.Optional[exp.NoPrimaryIndexProperty]: + def _parse_contains_property(self) -> t.Optional[exp.SqlReadWriteProperty]: + if self._match_text_seq("SQL"): + return self.expression(exp.SqlReadWriteProperty, this="CONTAINS SQL") + return None + + def _parse_modifies_property(self) -> t.Optional[exp.SqlReadWriteProperty]: + if self._match_text_seq("SQL", "DATA"): + return self.expression(exp.SqlReadWriteProperty, this="MODIFIES SQL DATA") + return None + + def _parse_no_property(self) -> t.Optional[exp.Expression]: if self._match_text_seq("PRIMARY", "INDEX"): return exp.NoPrimaryIndexProperty() + if self._match_text_seq("SQL"): + return self.expression(exp.SqlReadWriteProperty, this="NO SQL") return None def _parse_on_property(self) -> t.Optional[exp.Expression]: @@ -1861,6 +1891,11 @@ class Parser(metaclass=_Parser): return exp.OnCommitProperty(delete=True) return self.expression(exp.OnProperty, this=self._parse_schema(self._parse_id_var())) + def _parse_reads_property(self) -> t.Optional[exp.SqlReadWriteProperty]: + if self._match_text_seq("SQL", "DATA"): + return self.expression(exp.SqlReadWriteProperty, this="READS SQL DATA") + return None + def _parse_distkey(self) -> exp.DistKeyProperty: return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var)) @@ -1920,10 +1955,13 @@ class Parser(metaclass=_Parser): def _parse_describe(self) -> exp.Describe: kind = self._match_set(self.CREATABLES) and self._prev.text + extended = self._match_text_seq("EXTENDED") this = self._parse_table(schema=True) properties = self._parse_properties() expressions = properties.expressions if properties else None - return self.expression(exp.Describe, this=this, kind=kind, expressions=expressions) + return self.expression( + exp.Describe, this=this, extended=extended, kind=kind, expressions=expressions + ) def _parse_insert(self) -> exp.Insert: comments = ensure_list(self._prev_comments) @@ -2164,13 +2202,13 @@ class Parser(metaclass=_Parser): def _parse_value(self) -> exp.Tuple: if self._match(TokenType.L_PAREN): - expressions = self._parse_csv(self._parse_conjunction) + expressions = self._parse_csv(self._parse_expression) self._match_r_paren() return self.expression(exp.Tuple, expressions=expressions) # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows. # https://prestodb.io/docs/current/sql/values.html - return self.expression(exp.Tuple, expressions=[self._parse_conjunction()]) + return self.expression(exp.Tuple, expressions=[self._parse_expression()]) def _parse_projections(self) -> t.List[exp.Expression]: return self._parse_expressions() @@ -2212,7 +2250,7 @@ class Parser(metaclass=_Parser): kind = ( self._match(TokenType.ALIAS) and self._match_texts(("STRUCT", "VALUE")) - and self._prev.text + and self._prev.text.upper() ) if distinct: @@ -2261,7 +2299,7 @@ class Parser(metaclass=_Parser): if table else self._parse_select(nested=True, parse_set_operation=False) ) - this = self._parse_set_operations(self._parse_query_modifiers(this)) + this = self._parse_query_modifiers(self._parse_set_operations(this)) self._match_r_paren() @@ -2304,7 +2342,7 @@ class Parser(metaclass=_Parser): ) def _parse_cte(self) -> exp.CTE: - alias = self._parse_table_alias() + alias = self._parse_table_alias(self.ID_VAR_TOKENS) if not alias or not alias.this: self.raise_error("Expected CTE to have alias") @@ -2490,13 +2528,14 @@ class Parser(metaclass=_Parser): ) def _parse_lateral(self) -> t.Optional[exp.Lateral]: - outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY) cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY) + if not cross_apply and self._match_pair(TokenType.OUTER, TokenType.APPLY): + cross_apply = False - if outer_apply or cross_apply: + if cross_apply is not None: this = self._parse_select(table=True) view = None - outer = not cross_apply + outer = None elif self._match(TokenType.LATERAL): this = self._parse_select(table=True) view = self._match(TokenType.VIEW) @@ -2529,7 +2568,14 @@ class Parser(metaclass=_Parser): else: table_alias = self._parse_table_alias() - return self.expression(exp.Lateral, this=this, view=view, outer=outer, alias=table_alias) + return self.expression( + exp.Lateral, + this=this, + view=view, + outer=outer, + alias=table_alias, + cross_apply=cross_apply, + ) def _parse_join_parts( self, @@ -2563,9 +2609,6 @@ class Parser(metaclass=_Parser): if not skip_join_token and not join and not outer_apply and not cross_apply: return None - if outer_apply: - side = Token(TokenType.LEFT, "LEFT") - kwargs: t.Dict[str, t.Any] = {"this": self._parse_table(parse_bracket=parse_bracket)} if method: @@ -2755,8 +2798,10 @@ class Parser(metaclass=_Parser): if alias: this.set("alias", alias) - if self._match_text_seq("AT"): - this.set("index", self._parse_id_var()) + if isinstance(this, exp.Table) and self._match_text_seq("AT"): + return self.expression( + exp.AtIndex, this=this.to_column(copy=False), expression=self._parse_id_var() + ) this.set("hints", self._parse_table_hints()) @@ -2865,15 +2910,10 @@ class Parser(metaclass=_Parser): bucket_denominator = None bucket_field = None percent = None - rows = None size = None seed = None - kind = ( - self._prev.text if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE" - ) - method = self._parse_var(tokens=(TokenType.ROW,)) - + method = self._parse_var(tokens=(TokenType.ROW,), upper=True) matched_l_paren = self._match(TokenType.L_PAREN) if self.TABLESAMPLE_CSV: @@ -2895,16 +2935,16 @@ class Parser(metaclass=_Parser): bucket_field = self._parse_field() elif self._match_set((TokenType.PERCENT, TokenType.MOD)): percent = num - elif self._match(TokenType.ROWS): - rows = num - elif num: + elif self._match(TokenType.ROWS) or not self.dialect.TABLESAMPLE_SIZE_IS_PERCENT: size = num + else: + percent = num if matched_l_paren: self._match_r_paren() if self._match(TokenType.L_PAREN): - method = self._parse_var() + method = self._parse_var(upper=True) seed = self._match(TokenType.COMMA) and self._parse_number() self._match_r_paren() elif self._match_texts(("SEED", "REPEATABLE")): @@ -2918,10 +2958,8 @@ class Parser(metaclass=_Parser): bucket_denominator=bucket_denominator, bucket_field=bucket_field, percent=percent, - rows=rows, size=size, seed=seed, - kind=kind, ) def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]: @@ -2946,6 +2984,27 @@ class Parser(metaclass=_Parser): exp.Pivot, this=this, expressions=expressions, using=using, group=group ) + def _parse_pivot_in(self) -> exp.In: + def _parse_aliased_expression() -> t.Optional[exp.Expression]: + this = self._parse_conjunction() + + self._match(TokenType.ALIAS) + alias = self._parse_field() + if alias: + return self.expression(exp.PivotAlias, this=this, alias=alias) + + return this + + value = self._parse_column() + + if not self._match_pair(TokenType.IN, TokenType.L_PAREN): + self.raise_error("Expecting IN (") + + aliased_expressions = self._parse_csv(_parse_aliased_expression) + + self._match_r_paren() + return self.expression(exp.In, this=value, expressions=aliased_expressions) + def _parse_pivot(self) -> t.Optional[exp.Pivot]: index = self._index include_nulls = None @@ -2964,7 +3023,6 @@ class Parser(metaclass=_Parser): return None expressions = [] - field = None if not self._match(TokenType.L_PAREN): self._retreat(index) @@ -2981,12 +3039,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.FOR): self.raise_error("Expecting FOR") - value = self._parse_column() - - if not self._match(TokenType.IN): - self.raise_error("Expecting IN") - - field = self._parse_in(value, alias=True) + field = self._parse_pivot_in() self._match_r_paren() @@ -3132,14 +3185,19 @@ class Parser(metaclass=_Parser): def _parse_order( self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False ) -> t.Optional[exp.Expression]: + siblings = None if not skip_order_token and not self._match(TokenType.ORDER_BY): - return this + if not self._match(TokenType.ORDER_SIBLINGS_BY): + return this + + siblings = True return self.expression( exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered), interpolate=self._parse_interpolate(), + siblings=siblings, ) def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]: @@ -3213,7 +3271,7 @@ class Parser(metaclass=_Parser): if self._match(TokenType.FETCH): direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) - direction = self._prev.text if direction else "FIRST" + direction = self._prev.text.upper() if direction else "FIRST" count = self._parse_field(tokens=self.FETCH_TOKENS) percent = self._match(TokenType.PERCENT) @@ -3398,10 +3456,10 @@ class Parser(metaclass=_Parser): return this return self.expression(exp.Escape, this=this, expression=self._parse_string()) - def _parse_interval(self) -> t.Optional[exp.Interval]: + def _parse_interval(self, match_interval: bool = True) -> t.Optional[exp.Interval]: index = self._index - if not self._match(TokenType.INTERVAL): + if not self._match(TokenType.INTERVAL) and match_interval: return None if self._match(TokenType.STRING, advance=False): @@ -3409,11 +3467,19 @@ class Parser(metaclass=_Parser): else: this = self._parse_term() - if not this: + if not this or ( + isinstance(this, exp.Column) + and not this.table + and not this.this.quoted + and this.name.upper() == "IS" + ): self._retreat(index) return None - unit = self._parse_function() or self._parse_var(any_token=True) + unit = self._parse_function() or ( + not self._match(TokenType.ALIAS, advance=False) + and self._parse_var(any_token=True, upper=True) + ) # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse # each INTERVAL expression into this canonical form so it's easy to transpile @@ -3429,7 +3495,7 @@ class Parser(metaclass=_Parser): self._retreat(self._index - 1) this = exp.Literal.string(parts[0]) - unit = self.expression(exp.Var, this=parts[1]) + unit = self.expression(exp.Var, this=parts[1].upper()) return self.expression(exp.Interval, this=this, unit=unit) @@ -3489,6 +3555,12 @@ class Parser(metaclass=_Parser): def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]: interval = parse_interval and self._parse_interval() if interval: + # Convert INTERVAL 'val_1' unit_1 ... 'val_n' unit_n into a sum of intervals + while self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False): + interval = self.expression( # type: ignore + exp.Add, this=interval, expression=self._parse_interval(match_interval=False) + ) + return interval index = self._index @@ -3552,10 +3624,10 @@ class Parser(metaclass=_Parser): type_token = self._prev.token_type if type_token == TokenType.PSEUDO_TYPE: - return self.expression(exp.PseudoType, this=self._prev.text) + return self.expression(exp.PseudoType, this=self._prev.text.upper()) if type_token == TokenType.OBJECT_IDENTIFIER: - return self.expression(exp.ObjectIdentifier, this=self._prev.text) + return self.expression(exp.ObjectIdentifier, this=self._prev.text.upper()) nested = type_token in self.NESTED_TYPE_TOKENS is_struct = type_token in self.STRUCT_TYPE_TOKENS @@ -3587,7 +3659,7 @@ class Parser(metaclass=_Parser): if nested and self._match(TokenType.LT): if is_struct: - expressions = self._parse_csv(self._parse_struct_types) + expressions = self._parse_csv(lambda: self._parse_struct_types(type_required=True)) else: expressions = self._parse_csv( lambda: self._parse_types( @@ -3662,10 +3734,19 @@ class Parser(metaclass=_Parser): return this - def _parse_struct_types(self) -> t.Optional[exp.Expression]: + def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]: + index = self._index this = self._parse_type(parse_interval=False) or self._parse_id_var() self._match(TokenType.COLON) - return self._parse_column_def(this) + column_def = self._parse_column_def(this) + + if type_required and ( + (isinstance(this, exp.Column) and this.this is column_def) or this is column_def + ): + self._retreat(index) + return self._parse_types() + + return column_def def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match_text_seq("AT", "TIME", "ZONE"): @@ -4025,6 +4106,12 @@ class Parser(metaclass=_Parser): return exp.AutoIncrementColumnConstraint() + def _parse_auto_property(self) -> t.Optional[exp.AutoRefreshProperty]: + if not self._match_text_seq("REFRESH"): + self._retreat(self._index - 1) + return None + return self.expression(exp.AutoRefreshProperty, this=self._parse_var(upper=True)) + def _parse_compress(self) -> exp.CompressColumnConstraint: if self._match(TokenType.L_PAREN, advance=False): return self.expression( @@ -4230,8 +4317,10 @@ class Parser(metaclass=_Parser): def _parse_primary_key_part(self) -> t.Optional[exp.Expression]: return self._parse_field() - def _parse_period_for_system_time(self) -> exp.PeriodForSystemTimeConstraint: - self._match(TokenType.TIMESTAMP_SNAPSHOT) + def _parse_period_for_system_time(self) -> t.Optional[exp.PeriodForSystemTimeConstraint]: + if not self._match(TokenType.TIMESTAMP_SNAPSHOT): + self._retreat(self._index - 1) + return None id_vars = self._parse_wrapped_id_vars() return self.expression( @@ -4257,22 +4346,17 @@ class Parser(metaclass=_Parser): options = self._parse_key_constraint_options() return self.expression(exp.PrimaryKey, expressions=expressions, options=options) + def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]: + return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True)) + def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)): return this bracket_kind = self._prev.token_type - - if self._match(TokenType.COLON): - expressions: t.List[exp.Expression] = [ - self.expression(exp.Slice, expression=self._parse_conjunction()) - ] - else: - expressions = self._parse_csv( - lambda: self._parse_slice( - self._parse_alias(self._parse_conjunction(), explicit=True) - ) - ) + expressions = self._parse_csv( + lambda: self._parse_bracket_key_value(is_map=bracket_kind == TokenType.L_BRACE) + ) if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET: self.raise_error("Expected ]") @@ -4313,7 +4397,10 @@ class Parser(metaclass=_Parser): default = self._parse_conjunction() if not self._match(TokenType.END): - self.raise_error("Expected END after CASE", self._prev) + if isinstance(default, exp.Interval) and default.this.sql().upper() == "END": + default = exp.column("interval") + else: + self.raise_error("Expected END after CASE", self._prev) return self._parse_window( self.expression(exp.Case, comments=comments, this=expression, ifs=ifs, default=default) @@ -4514,7 +4601,7 @@ class Parser(metaclass=_Parser): def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]: self._match_text_seq("KEY") key = self._parse_column() - self._match_set((TokenType.COLON, TokenType.COMMA)) + self._match_set(self.JSON_KEY_VALUE_SEPARATOR_TOKENS) self._match_text_seq("VALUE") value = self._parse_bitwise() @@ -4536,7 +4623,15 @@ class Parser(metaclass=_Parser): return None - def _parse_json_object(self) -> exp.JSONObject: + @t.overload + def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject: + ... + + @t.overload + def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg: + ... + + def _parse_json_object(self, agg=False): star = self._parse_star() expressions = ( [star] @@ -4559,7 +4654,7 @@ class Parser(metaclass=_Parser): encoding = self._match_text_seq("ENCODING") and self._parse_var() return self.expression( - exp.JSONObject, + exp.JSONObjectAgg if agg else exp.JSONObject, expressions=expressions, null_handling=null_handling, unique_keys=unique_keys, @@ -4873,10 +4968,17 @@ class Parser(metaclass=_Parser): self._match_r_paren(aliases) return aliases - alias = self._parse_id_var(any_token) + alias = self._parse_id_var(any_token) or ( + self.STRING_ALIASES and self._parse_string_as_identifier() + ) if alias: - return self.expression(exp.Alias, comments=comments, this=this, alias=alias) + this = self.expression(exp.Alias, comments=comments, this=this, alias=alias) + + # Moves the comment next to the alias in `expr /* comment */ AS alias` + if not this.comments and this.this.comments: + this.comments = this.this.comments + this.this.comments = None return this @@ -4915,14 +5017,19 @@ class Parser(metaclass=_Parser): return self._parse_placeholder() def _parse_var( - self, any_token: bool = False, tokens: t.Optional[t.Collection[TokenType]] = None + self, + any_token: bool = False, + tokens: t.Optional[t.Collection[TokenType]] = None, + upper: bool = False, ) -> t.Optional[exp.Expression]: if ( (any_token and self._advance_any()) or self._match(TokenType.VAR) or (self._match_set(tokens) if tokens else False) ): - return self.expression(exp.Var, this=self._prev.text) + return self.expression( + exp.Var, this=self._prev.text.upper() if upper else self._prev.text + ) return self._parse_placeholder() def _advance_any(self, ignore_reserved: bool = False) -> t.Optional[Token]: @@ -5418,6 +5525,42 @@ class Parser(metaclass=_Parser): condition=condition, ) + def _parse_heredoc(self) -> t.Optional[exp.Heredoc]: + if self._match(TokenType.HEREDOC_STRING): + return self.expression(exp.Heredoc, this=self._prev.text) + + if not self._match_text_seq("$"): + return None + + tags = ["$"] + tag_text = None + + if self._is_connected(): + self._advance() + tags.append(self._prev.text.upper()) + else: + self.raise_error("No closing $ found") + + if tags[-1] != "$": + if self._is_connected() and self._match_text_seq("$"): + tag_text = tags[-1] + tags.append("$") + else: + self.raise_error("No closing $ found") + + heredoc_start = self._curr + + while self._curr: + if self._match_text_seq(*tags, advance=False): + this = self._find_sql(heredoc_start, self._prev) + self._advance(len(tags)) + return self.expression(exp.Heredoc, this=this, tag=tag_text) + + self._advance() + + self.raise_error(f"No closing {''.join(tags)} found") + return None + def _find_parser( self, parsers: t.Dict[str, t.Callable], trie: t.Dict ) -> t.Optional[t.Callable]: |