From 67578a7602a5be7eb51f324086c8d49bcf8b7498 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 16 Jun 2023 11:41:18 +0200 Subject: Merging upstream version 16.2.1. Signed-off-by: Daniel Baumann --- sqlglot/parser.py | 682 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 350 insertions(+), 332 deletions(-) (limited to 'sqlglot/parser.py') diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 96bd6e3..d6888c7 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -6,7 +6,8 @@ from collections import defaultdict from sqlglot import exp from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors -from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get +from sqlglot.helper import apply_index_offset, ensure_list, seq_get +from sqlglot.time import format_time from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import in_trie, new_trie @@ -25,13 +26,14 @@ def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap: for i in range(0, len(args), 2): keys.append(args[i]) values.append(args[i + 1]) + return exp.VarMap( keys=exp.Array(expressions=keys), values=exp.Array(expressions=values), ) -def parse_like(args: t.List) -> exp.Expression: +def parse_like(args: t.List) -> exp.Escape | exp.Like: like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0)) return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like @@ -47,33 +49,26 @@ def binary_range_parser( class _Parser(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) - klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS) - klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS) + + klass.SHOW_TRIE = new_trie(key.split(" ") for key in klass.SHOW_PARSERS) + klass.SET_TRIE = new_trie(key.split(" ") for key in klass.SET_PARSERS) return klass class Parser(metaclass=_Parser): """ - Parser consumes a list of tokens produced by the `sqlglot.tokens.Tokenizer` and produces - a parsed syntax tree. + Parser consumes a list of tokens produced by the Tokenizer and produces a parsed syntax tree. Args: - error_level: the desired error level. + error_level: The desired error level. Default: ErrorLevel.IMMEDIATE - error_message_context: determines the amount of context to capture from a + error_message_context: Determines the amount of context to capture from a query string when displaying the error message (in number of characters). - Default: 50. - index_offset: Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list. - Default: 0 - alias_post_tablesample: If the table alias comes after tablesample. - Default: False + Default: 100 max_errors: Maximum number of error messages to include in a raised ParseError. This is only relevant if error_level is ErrorLevel.RAISE. Default: 3 - null_ordering: Indicates the default null ordering method to use if not explicitly set. - Options are "nulls_are_small", "nulls_are_large", "nulls_are_last". - Default: "nulls_are_small" """ FUNCTIONS: t.Dict[str, t.Callable] = { @@ -83,7 +78,6 @@ class Parser(metaclass=_Parser): to=exp.DataType(this=exp.DataType.Type.TEXT), ), "GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)), - "IFNULL": exp.Coalesce.from_arg_list, "LIKE": parse_like, "TIME_TO_TIME_STR": lambda args: exp.Cast( this=seq_get(args, 0), @@ -108,8 +102,6 @@ class Parser(metaclass=_Parser): TokenType.CURRENT_USER: exp.CurrentUser, } - JOIN_HINTS: t.Set[str] = set() - NESTED_TYPE_TOKENS = { TokenType.ARRAY, TokenType.MAP, @@ -117,6 +109,10 @@ class Parser(metaclass=_Parser): TokenType.STRUCT, } + ENUM_TYPE_TOKENS = { + TokenType.ENUM, + } + TYPE_TOKENS = { TokenType.BIT, TokenType.BOOLEAN, @@ -188,6 +184,7 @@ class Parser(metaclass=_Parser): TokenType.VARIANT, TokenType.OBJECT, TokenType.INET, + TokenType.ENUM, *NESTED_TYPE_TOKENS, } @@ -198,7 +195,10 @@ class Parser(metaclass=_Parser): TokenType.SOME: exp.Any, } - RESERVED_KEYWORDS = {*Tokenizer.SINGLE_TOKENS.values(), TokenType.SELECT} + RESERVED_KEYWORDS = { + *Tokenizer.SINGLE_TOKENS.values(), + TokenType.SELECT, + } DB_CREATABLES = { TokenType.DATABASE, @@ -216,6 +216,7 @@ class Parser(metaclass=_Parser): *DB_CREATABLES, } + # Tokens that can represent identifiers ID_VAR_TOKENS = { TokenType.VAR, TokenType.ANTI, @@ -224,6 +225,7 @@ class Parser(metaclass=_Parser): TokenType.AUTO_INCREMENT, TokenType.BEGIN, TokenType.CACHE, + TokenType.CASE, TokenType.COLLATE, TokenType.COMMAND, TokenType.COMMENT, @@ -274,6 +276,7 @@ class Parser(metaclass=_Parser): TokenType.TRUE, TokenType.UNIQUE, TokenType.UNPIVOT, + TokenType.UPDATE, TokenType.VOLATILE, TokenType.WINDOW, *CREATABLES, @@ -409,6 +412,8 @@ class Parser(metaclass=_Parser): TokenType.ANTI, } + JOIN_HINTS: t.Set[str] = set() + LAMBDAS = { TokenType.ARROW: lambda self, expressions: self.expression( exp.Lambda, @@ -420,7 +425,7 @@ class Parser(metaclass=_Parser): ), TokenType.FARROW: lambda self, expressions: self.expression( exp.Kwarg, - this=exp.Var(this=expressions[0].name), + this=exp.var(expressions[0].name), expression=self._parse_conjunction(), ), } @@ -515,7 +520,7 @@ class Parser(metaclass=_Parser): TokenType.USE: lambda self: self.expression( exp.Use, kind=self._match_texts(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA")) - and exp.Var(this=self._prev.text), + and exp.var(self._prev.text), this=self._parse_table(schema=False), ), } @@ -634,6 +639,7 @@ class Parser(metaclass=_Parser): "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property), "TEMP": lambda self: self.expression(exp.TemporaryProperty), "TEMPORARY": lambda self: self.expression(exp.TemporaryProperty), + "TO": lambda self: self._parse_to_table(), "TRANSIENT": lambda self: self.expression(exp.TransientProperty), "TTL": lambda self: self._parse_ttl(), "USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty), @@ -710,6 +716,7 @@ class Parser(metaclass=_Parser): FUNCTION_PARSERS: t.Dict[str, t.Callable] = { "CAST": lambda self: self._parse_cast(self.STRICT_CAST), + "CONCAT": lambda self: self._parse_concat(), "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), "DECODE": lambda self: self._parse_decode(), "EXTRACT": lambda self: self._parse_extract(), @@ -755,8 +762,11 @@ class Parser(metaclass=_Parser): MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) - TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} + DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN} + PRE_VOLATILE_TOKENS = {TokenType.CREATE, TokenType.REPLACE, TokenType.UNIQUE} + + TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} TRANSACTION_CHARACTERISTICS = { "ISOLATION LEVEL REPEATABLE READ", "ISOLATION LEVEL READ COMMITTED", @@ -778,6 +788,8 @@ class Parser(metaclass=_Parser): STRICT_CAST = True + CONCAT_NULL_OUTPUTS_STRING = False # A NULL arg in CONCAT yields NULL by default + CONVERT_TYPE_FIRST = False PREFIXED_PIVOT_COLUMNS = False @@ -789,40 +801,39 @@ class Parser(metaclass=_Parser): __slots__ = ( "error_level", "error_message_context", + "max_errors", "sql", "errors", - "index_offset", - "unnest_column_only", - "alias_post_tablesample", - "max_errors", - "null_ordering", "_tokens", "_index", "_curr", "_next", "_prev", "_prev_comments", - "_show_trie", - "_set_trie", ) + # Autofilled + INDEX_OFFSET: int = 0 + UNNEST_COLUMN_ONLY: bool = False + ALIAS_POST_TABLESAMPLE: bool = False + STRICT_STRING_CONCAT = False + NULL_ORDERING: str = "nulls_are_small" + SHOW_TRIE: t.Dict = {} + SET_TRIE: t.Dict = {} + FORMAT_MAPPING: t.Dict[str, str] = {} + FORMAT_TRIE: t.Dict = {} + TIME_MAPPING: t.Dict[str, str] = {} + TIME_TRIE: t.Dict = {} + def __init__( self, error_level: t.Optional[ErrorLevel] = None, error_message_context: int = 100, - index_offset: int = 0, - unnest_column_only: bool = False, - alias_post_tablesample: bool = False, max_errors: int = 3, - null_ordering: t.Optional[str] = None, ): self.error_level = error_level or ErrorLevel.IMMEDIATE self.error_message_context = error_message_context - self.index_offset = index_offset - self.unnest_column_only = unnest_column_only - self.alias_post_tablesample = alias_post_tablesample self.max_errors = max_errors - self.null_ordering = null_ordering self.reset() def reset(self): @@ -843,11 +854,11 @@ class Parser(metaclass=_Parser): per parsed SQL statement. Args: - raw_tokens: the list of tokens. - sql: the original SQL string, used to produce helpful debug messages. + raw_tokens: The list of tokens. + sql: The original SQL string, used to produce helpful debug messages. Returns: - The list of syntax trees. + The list of the produced syntax trees. """ return self._parse( parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql @@ -865,23 +876,25 @@ class Parser(metaclass=_Parser): of them, stopping at the first for which the parsing succeeds. Args: - expression_types: the expression type(s) to try and parse the token list into. - raw_tokens: the list of tokens. - sql: the original SQL string, used to produce helpful debug messages. + expression_types: The expression type(s) to try and parse the token list into. + raw_tokens: The list of tokens. + sql: The original SQL string, used to produce helpful debug messages. Returns: The target Expression. """ errors = [] - for expression_type in ensure_collection(expression_types): + for expression_type in ensure_list(expression_types): parser = self.EXPRESSION_PARSERS.get(expression_type) if not parser: raise TypeError(f"No parser registered for {expression_type}") + try: return self._parse(parser, raw_tokens, sql) except ParseError as e: e.errors[0]["into_expression"] = expression_type errors.append(e) + raise ParseError( f"Failed to parse '{sql or raw_tokens}' into {expression_types}", errors=merge_errors(errors), @@ -895,6 +908,7 @@ class Parser(metaclass=_Parser): ) -> t.List[t.Optional[exp.Expression]]: self.reset() self.sql = sql or "" + total = len(raw_tokens) chunks: t.List[t.List[Token]] = [[]] @@ -922,9 +936,7 @@ class Parser(metaclass=_Parser): return expressions def check_errors(self) -> None: - """ - Logs or raises any found errors, depending on the chosen error level setting. - """ + """Logs or raises any found errors, depending on the chosen error level setting.""" if self.error_level == ErrorLevel.WARN: for error in self.errors: logger.error(str(error)) @@ -969,39 +981,38 @@ class Parser(metaclass=_Parser): Creates a new, validated Expression. Args: - exp_class: the expression class to instantiate. - comments: an optional list of comments to attach to the expression. - kwargs: the arguments to set for the expression along with their respective values. + exp_class: The expression class to instantiate. + comments: An optional list of comments to attach to the expression. + kwargs: The arguments to set for the expression along with their respective values. Returns: The target expression. """ instance = exp_class(**kwargs) instance.add_comments(comments) if comments else self._add_comments(instance) - self.validate_expression(instance) - return instance + return self.validate_expression(instance) def _add_comments(self, expression: t.Optional[exp.Expression]) -> None: if expression and self._prev_comments: expression.add_comments(self._prev_comments) self._prev_comments = None - def validate_expression( - self, expression: exp.Expression, args: t.Optional[t.List] = None - ) -> None: + def validate_expression(self, expression: E, args: t.Optional[t.List] = None) -> E: """ - Validates an already instantiated expression, making sure that all its mandatory arguments - are set. + Validates an Expression, making sure that all its mandatory arguments are set. Args: - expression: the expression to validate. - args: an optional list of items that was used to instantiate the expression, if it's a Func. + expression: The expression to validate. + args: An optional list of items that was used to instantiate the expression, if it's a Func. + + Returns: + The validated expression. """ - if self.error_level == ErrorLevel.IGNORE: - return + if self.error_level != ErrorLevel.IGNORE: + for error_message in expression.error_messages(args): + self.raise_error(error_message) - for error_message in expression.error_messages(args): - self.raise_error(error_message) + return expression def _find_sql(self, start: Token, end: Token) -> str: return self.sql[start.start : end.end + 1] @@ -1010,6 +1021,7 @@ class Parser(metaclass=_Parser): self._index += times self._curr = seq_get(self._tokens, self._index) self._next = seq_get(self._tokens, self._index + 1) + if self._index > 0: self._prev = self._tokens[self._index - 1] self._prev_comments = self._prev.comments @@ -1031,7 +1043,6 @@ class Parser(metaclass=_Parser): self._match(TokenType.ON) kind = self._match_set(self.CREATABLES) and self._prev - if not kind: return self._parse_as_command(start) @@ -1050,6 +1061,12 @@ class Parser(metaclass=_Parser): exp.Comment, this=this, kind=kind.text, expression=self._parse_string(), exists=exists ) + def _parse_to_table( + self, + ) -> exp.ToTableProperty: + table = self._parse_table_parts(schema=True) + return self.expression(exp.ToTableProperty, this=table) + # https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl def _parse_ttl(self) -> exp.Expression: def _parse_ttl_action() -> t.Optional[exp.Expression]: @@ -1102,10 +1119,11 @@ class Parser(metaclass=_Parser): expression = self._parse_set_operations(expression) if expression else self._parse_select() return self._parse_query_modifiers(expression) - def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]: + def _parse_drop(self) -> exp.Drop | exp.Command: start = self._prev temporary = self._match(TokenType.TEMPORARY) materialized = self._match_text_seq("MATERIALIZED") + kind = self._match_set(self.CREATABLES) and self._prev.text if not kind: return self._parse_as_command(start) @@ -1129,21 +1147,23 @@ class Parser(metaclass=_Parser): and self._match(TokenType.EXISTS) ) - def _parse_create(self) -> t.Optional[exp.Expression]: + def _parse_create(self) -> exp.Create | exp.Command: + # Note: this can't be None because we've matched a statement parser start = self._prev - replace = self._prev.text.upper() == "REPLACE" or self._match_pair( + replace = start.text.upper() == "REPLACE" or self._match_pair( TokenType.OR, TokenType.REPLACE ) unique = self._match(TokenType.UNIQUE) if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False): - self._match(TokenType.TABLE) + self._advance() properties = None create_token = self._match_set(self.CREATABLES) and self._prev if not create_token: - properties = self._parse_properties() # exp.Properties.Location.POST_CREATE + # exp.Properties.Location.POST_CREATE + properties = self._parse_properties() create_token = self._match_set(self.CREATABLES) and self._prev if not properties or not create_token: @@ -1157,7 +1177,7 @@ class Parser(metaclass=_Parser): begin = None clone = None - def extend_props(temp_props: t.Optional[exp.Expression]) -> None: + def extend_props(temp_props: t.Optional[exp.Properties]) -> None: nonlocal properties if properties and temp_props: properties.expressions.extend(temp_props.expressions) @@ -1166,6 +1186,8 @@ class Parser(metaclass=_Parser): if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): this = self._parse_user_defined_function(kind=create_token.token_type) + + # exp.Properties.Location.POST_SCHEMA ("schema" here is the UDF's type signature) extend_props(self._parse_properties()) self._match(TokenType.ALIAS) @@ -1190,13 +1212,8 @@ class Parser(metaclass=_Parser): extend_props(self._parse_properties()) self._match(TokenType.ALIAS) - - # exp.Properties.Location.POST_ALIAS - if not ( - self._match(TokenType.SELECT, advance=False) - or self._match(TokenType.WITH, advance=False) - or self._match(TokenType.L_PAREN, advance=False) - ): + if not self._match_set(self.DDL_SELECT_TOKENS, advance=False): + # exp.Properties.Location.POST_ALIAS extend_props(self._parse_properties()) expression = self._parse_ddl_select() @@ -1206,7 +1223,7 @@ class Parser(metaclass=_Parser): while True: index = self._parse_index() - # exp.Properties.Location.POST_EXPRESSION or exp.Properties.Location.POST_INDEX + # exp.Properties.Location.POST_EXPRESSION and POST_INDEX extend_props(self._parse_properties()) if not index: @@ -1296,7 +1313,7 @@ class Parser(metaclass=_Parser): return None - def _parse_stored(self) -> exp.Expression: + def _parse_stored(self) -> exp.FileFormatProperty: self._match(TokenType.ALIAS) input_format = self._parse_string() if self._match_text_seq("INPUTFORMAT") else None @@ -1311,14 +1328,13 @@ class Parser(metaclass=_Parser): else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(), ) - def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression: + def _parse_property_assignment(self, exp_class: t.Type[E]) -> E: self._match(TokenType.EQ) self._match(TokenType.ALIAS) return self.expression(exp_class, this=self._parse_field()) - def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Expression]: + def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Properties]: properties = [] - while True: if before: prop = self._parse_property_before() @@ -1335,29 +1351,25 @@ class Parser(metaclass=_Parser): return None - def _parse_fallback(self, no: bool = False) -> exp.Expression: + def _parse_fallback(self, no: bool = False) -> exp.FallbackProperty: return self.expression( exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION") ) - def _parse_volatile_property(self) -> exp.Expression: + def _parse_volatile_property(self) -> exp.VolatileProperty | exp.StabilityProperty: if self._index >= 2: pre_volatile_token = self._tokens[self._index - 2] else: pre_volatile_token = None - if pre_volatile_token and pre_volatile_token.token_type in ( - TokenType.CREATE, - TokenType.REPLACE, - TokenType.UNIQUE, - ): + if pre_volatile_token and pre_volatile_token.token_type in self.PRE_VOLATILE_TOKENS: return exp.VolatileProperty() return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE")) def _parse_with_property( self, - ) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]: + ) -> t.Optional[exp.Expression] | t.List[t.Optional[exp.Expression]]: self._match(TokenType.WITH) if self._match(TokenType.L_PAREN, advance=False): return self._parse_wrapped_csv(self._parse_property) @@ -1376,7 +1388,7 @@ class Parser(metaclass=_Parser): return self._parse_withisolatedloading() # https://dev.mysql.com/doc/refman/8.0/en/create-view.html - def _parse_definer(self) -> t.Optional[exp.Expression]: + def _parse_definer(self) -> t.Optional[exp.DefinerProperty]: self._match(TokenType.EQ) user = self._parse_id_var() @@ -1388,18 +1400,18 @@ class Parser(metaclass=_Parser): return exp.DefinerProperty(this=f"{user}@{host}") - def _parse_withjournaltable(self) -> exp.Expression: + def _parse_withjournaltable(self) -> exp.WithJournalTableProperty: self._match(TokenType.TABLE) self._match(TokenType.EQ) return self.expression(exp.WithJournalTableProperty, this=self._parse_table_parts()) - def _parse_log(self, no: bool = False) -> exp.Expression: + def _parse_log(self, no: bool = False) -> exp.LogProperty: return self.expression(exp.LogProperty, no=no) - def _parse_journal(self, **kwargs) -> exp.Expression: + def _parse_journal(self, **kwargs) -> exp.JournalProperty: return self.expression(exp.JournalProperty, **kwargs) - def _parse_checksum(self) -> exp.Expression: + def _parse_checksum(self) -> exp.ChecksumProperty: self._match(TokenType.EQ) on = None @@ -1407,53 +1419,47 @@ class Parser(metaclass=_Parser): on = True elif self._match_text_seq("OFF"): on = False - default = self._match(TokenType.DEFAULT) - return self.expression( - exp.ChecksumProperty, - on=on, - default=default, - ) + return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT)) - def _parse_cluster(self) -> t.Optional[exp.Expression]: + def _parse_cluster(self) -> t.Optional[exp.Cluster]: if not self._match_text_seq("BY"): self._retreat(self._index - 1) return None - return self.expression( - exp.Cluster, - expressions=self._parse_csv(self._parse_ordered), - ) - def _parse_freespace(self) -> exp.Expression: + return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered)) + + def _parse_freespace(self) -> exp.FreespaceProperty: self._match(TokenType.EQ) return self.expression( exp.FreespaceProperty, this=self._parse_number(), percent=self._match(TokenType.PERCENT) ) - def _parse_mergeblockratio(self, no: bool = False, default: bool = False) -> exp.Expression: + def _parse_mergeblockratio( + self, no: bool = False, default: bool = False + ) -> exp.MergeBlockRatioProperty: if self._match(TokenType.EQ): return self.expression( exp.MergeBlockRatioProperty, this=self._parse_number(), percent=self._match(TokenType.PERCENT), ) - return self.expression( - exp.MergeBlockRatioProperty, - no=no, - default=default, - ) + + return self.expression(exp.MergeBlockRatioProperty, no=no, default=default) def _parse_datablocksize( self, default: t.Optional[bool] = None, minimum: t.Optional[bool] = None, maximum: t.Optional[bool] = None, - ) -> exp.Expression: + ) -> exp.DataBlocksizeProperty: self._match(TokenType.EQ) size = self._parse_number() + units = None if self._match_texts(("BYTES", "KBYTES", "KILOBYTES")): units = self._prev.text + return self.expression( exp.DataBlocksizeProperty, size=size, @@ -1463,12 +1469,13 @@ class Parser(metaclass=_Parser): maximum=maximum, ) - def _parse_blockcompression(self) -> exp.Expression: + def _parse_blockcompression(self) -> exp.BlockCompressionProperty: self._match(TokenType.EQ) always = self._match_text_seq("ALWAYS") manual = self._match_text_seq("MANUAL") never = self._match_text_seq("NEVER") default = self._match_text_seq("DEFAULT") + autotemp = None if self._match_text_seq("AUTOTEMP"): autotemp = self._parse_schema() @@ -1482,7 +1489,7 @@ class Parser(metaclass=_Parser): autotemp=autotemp, ) - def _parse_withisolatedloading(self) -> exp.Expression: + def _parse_withisolatedloading(self) -> exp.IsolatedLoadingProperty: no = self._match_text_seq("NO") concurrent = self._match_text_seq("CONCURRENT") self._match_text_seq("ISOLATED", "LOADING") @@ -1498,7 +1505,7 @@ class Parser(metaclass=_Parser): for_none=for_none, ) - def _parse_locking(self) -> exp.Expression: + def _parse_locking(self) -> exp.LockingProperty: if self._match(TokenType.TABLE): kind = "TABLE" elif self._match(TokenType.VIEW): @@ -1553,14 +1560,14 @@ class Parser(metaclass=_Parser): return self._parse_csv(self._parse_conjunction) return [] - def _parse_partitioned_by(self) -> exp.Expression: + def _parse_partitioned_by(self) -> exp.PartitionedByProperty: self._match(TokenType.EQ) return self.expression( exp.PartitionedByProperty, this=self._parse_schema() or self._parse_bracket(self._parse_field()), ) - def _parse_withdata(self, no: bool = False) -> exp.Expression: + def _parse_withdata(self, no: bool = False) -> exp.WithDataProperty: if self._match_text_seq("AND", "STATISTICS"): statistics = True elif self._match_text_seq("AND", "NO", "STATISTICS"): @@ -1570,52 +1577,50 @@ class Parser(metaclass=_Parser): return self.expression(exp.WithDataProperty, no=no, statistics=statistics) - def _parse_no_property(self) -> t.Optional[exp.Property]: + def _parse_no_property(self) -> t.Optional[exp.NoPrimaryIndexProperty]: if self._match_text_seq("PRIMARY", "INDEX"): return exp.NoPrimaryIndexProperty() return None - def _parse_on_property(self) -> t.Optional[exp.Property]: + def _parse_on_property(self) -> t.Optional[exp.Expression]: if self._match_text_seq("COMMIT", "PRESERVE", "ROWS"): return exp.OnCommitProperty() elif self._match_text_seq("COMMIT", "DELETE", "ROWS"): return exp.OnCommitProperty(delete=True) return None - def _parse_distkey(self) -> exp.Expression: + def _parse_distkey(self) -> exp.DistKeyProperty: return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var)) - def _parse_create_like(self) -> t.Optional[exp.Expression]: + def _parse_create_like(self) -> t.Optional[exp.LikeProperty]: table = self._parse_table(schema=True) + options = [] while self._match_texts(("INCLUDING", "EXCLUDING")): this = self._prev.text.upper() - id_var = self._parse_id_var() + id_var = self._parse_id_var() if not id_var: return None options.append( - self.expression( - exp.Property, - this=this, - value=exp.Var(this=id_var.this.upper()), - ) + self.expression(exp.Property, this=this, value=exp.var(id_var.this.upper())) ) + return self.expression(exp.LikeProperty, this=table, expressions=options) - def _parse_sortkey(self, compound: bool = False) -> exp.Expression: + def _parse_sortkey(self, compound: bool = False) -> exp.SortKeyProperty: return self.expression( - exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound + exp.SortKeyProperty, this=self._parse_wrapped_id_vars(), compound=compound ) - def _parse_character_set(self, default: bool = False) -> exp.Expression: + def _parse_character_set(self, default: bool = False) -> exp.CharacterSetProperty: self._match(TokenType.EQ) return self.expression( exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default ) - def _parse_returns(self) -> exp.Expression: + def _parse_returns(self) -> exp.ReturnsProperty: value: t.Optional[exp.Expression] is_table = self._match(TokenType.TABLE) @@ -1629,19 +1634,18 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.GT): self.raise_error("Expecting >") else: - value = self._parse_schema(exp.Var(this="TABLE")) + value = self._parse_schema(exp.var("TABLE")) else: value = self._parse_types() return self.expression(exp.ReturnsProperty, this=value, is_table=is_table) - def _parse_describe(self) -> exp.Expression: + def _parse_describe(self) -> exp.Describe: kind = self._match_set(self.CREATABLES) and self._prev.text this = self._parse_table() - return self.expression(exp.Describe, this=this, kind=kind) - def _parse_insert(self) -> exp.Expression: + def _parse_insert(self) -> exp.Insert: overwrite = self._match(TokenType.OVERWRITE) local = self._match_text_seq("LOCAL") alternative = None @@ -1673,11 +1677,11 @@ class Parser(metaclass=_Parser): alternative=alternative, ) - def _parse_on_conflict(self) -> t.Optional[exp.Expression]: + def _parse_on_conflict(self) -> t.Optional[exp.OnConflict]: conflict = self._match_text_seq("ON", "CONFLICT") duplicate = self._match_text_seq("ON", "DUPLICATE", "KEY") - if not (conflict or duplicate): + if not conflict and not duplicate: return None nothing = None @@ -1707,18 +1711,20 @@ class Parser(metaclass=_Parser): constraint=constraint, ) - def _parse_returning(self) -> t.Optional[exp.Expression]: + def _parse_returning(self) -> t.Optional[exp.Returning]: if not self._match(TokenType.RETURNING): return None return self.expression(exp.Returning, expressions=self._parse_csv(self._parse_column)) - def _parse_row(self) -> t.Optional[exp.Expression]: + def _parse_row(self) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]: if not self._match(TokenType.FORMAT): return None return self._parse_row_format() - def _parse_row_format(self, match_row: bool = False) -> t.Optional[exp.Expression]: + def _parse_row_format( + self, match_row: bool = False + ) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]: if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT): return None @@ -1744,7 +1750,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.RowFormatDelimitedProperty, **kwargs) # type: ignore - def _parse_load(self) -> exp.Expression: + def _parse_load(self) -> exp.LoadData | exp.Command: if self._match_text_seq("DATA"): local = self._match_text_seq("LOCAL") self._match_text_seq("INPATH") @@ -1764,7 +1770,7 @@ class Parser(metaclass=_Parser): ) return self._parse_as_command(self._prev) - def _parse_delete(self) -> exp.Expression: + def _parse_delete(self) -> exp.Delete: self._match(TokenType.FROM) return self.expression( @@ -1775,7 +1781,7 @@ class Parser(metaclass=_Parser): returning=self._parse_returning(), ) - def _parse_update(self) -> exp.Expression: + def _parse_update(self) -> exp.Update: return self.expression( exp.Update, **{ # type: ignore @@ -1787,22 +1793,20 @@ class Parser(metaclass=_Parser): }, ) - def _parse_uncache(self) -> exp.Expression: + def _parse_uncache(self) -> exp.Uncache: if not self._match(TokenType.TABLE): self.raise_error("Expecting TABLE after UNCACHE") return self.expression( - exp.Uncache, - exists=self._parse_exists(), - this=self._parse_table(schema=True), + exp.Uncache, exists=self._parse_exists(), this=self._parse_table(schema=True) ) - def _parse_cache(self) -> exp.Expression: + def _parse_cache(self) -> exp.Cache: lazy = self._match_text_seq("LAZY") self._match(TokenType.TABLE) table = self._parse_table(schema=True) - options = [] + options = [] if self._match_text_seq("OPTIONS"): self._match_l_paren() k = self._parse_string() @@ -1820,7 +1824,7 @@ class Parser(metaclass=_Parser): expression=self._parse_select(nested=True), ) - def _parse_partition(self) -> t.Optional[exp.Expression]: + def _parse_partition(self) -> t.Optional[exp.Partition]: if not self._match(TokenType.PARTITION): return None @@ -1828,7 +1832,7 @@ class Parser(metaclass=_Parser): exp.Partition, expressions=self._parse_wrapped_csv(self._parse_conjunction) ) - def _parse_value(self) -> exp.Expression: + def _parse_value(self) -> exp.Tuple: if self._match(TokenType.L_PAREN): expressions = self._parse_csv(self._parse_conjunction) self._match_r_paren() @@ -1926,7 +1930,7 @@ class Parser(metaclass=_Parser): return self._parse_set_operations(this) - def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.Expression]: + def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.With]: if not skip_with_token and not self._match(TokenType.WITH): return None @@ -1946,22 +1950,19 @@ class Parser(metaclass=_Parser): exp.With, comments=comments, expressions=expressions, recursive=recursive ) - def _parse_cte(self) -> exp.Expression: + def _parse_cte(self) -> exp.CTE: alias = self._parse_table_alias() if not alias or not alias.this: self.raise_error("Expected CTE to have alias") self._match(TokenType.ALIAS) - return self.expression( - exp.CTE, - this=self._parse_wrapped(self._parse_statement), - alias=alias, + exp.CTE, this=self._parse_wrapped(self._parse_statement), alias=alias ) def _parse_table_alias( self, alias_tokens: t.Optional[t.Collection[TokenType]] = None - ) -> t.Optional[exp.Expression]: + ) -> t.Optional[exp.TableAlias]: any_token = self._match(TokenType.ALIAS) alias = ( self._parse_id_var(any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS) @@ -1982,9 +1983,10 @@ class Parser(metaclass=_Parser): def _parse_subquery( self, this: t.Optional[exp.Expression], parse_alias: bool = True - ) -> t.Optional[exp.Expression]: + ) -> t.Optional[exp.Subquery]: if not this: return None + return self.expression( exp.Subquery, this=this, @@ -2000,19 +2002,25 @@ class Parser(metaclass=_Parser): expression = parser(self) if expression: + if key == "limit": + offset = expression.args.pop("offset", None) + if offset: + this.set("offset", exp.Offset(expression=offset)) this.set(key, expression) return this - def _parse_hint(self) -> t.Optional[exp.Expression]: + def _parse_hint(self) -> t.Optional[exp.Hint]: if self._match(TokenType.HINT): hints = self._parse_csv(self._parse_function) + if not self._match_pair(TokenType.STAR, TokenType.SLASH): self.raise_error("Expected */ after HINT") + return self.expression(exp.Hint, expressions=hints) return None - def _parse_into(self) -> t.Optional[exp.Expression]: + def _parse_into(self) -> t.Optional[exp.Into]: if not self._match(TokenType.INTO): return None @@ -2039,7 +2047,7 @@ class Parser(metaclass=_Parser): this=self._parse_query_modifiers(this) if modifiers else this, ) - def _parse_match_recognize(self) -> t.Optional[exp.Expression]: + def _parse_match_recognize(self) -> t.Optional[exp.MatchRecognize]: if not self._match(TokenType.MATCH_RECOGNIZE): return None @@ -2052,7 +2060,7 @@ class Parser(metaclass=_Parser): ) if self._match_text_seq("ONE", "ROW", "PER", "MATCH"): - rows = exp.Var(this="ONE ROW PER MATCH") + rows = exp.var("ONE ROW PER MATCH") elif self._match_text_seq("ALL", "ROWS", "PER", "MATCH"): text = "ALL ROWS PER MATCH" if self._match_text_seq("SHOW", "EMPTY", "MATCHES"): @@ -2061,7 +2069,7 @@ class Parser(metaclass=_Parser): text += f" OMIT EMPTY MATCHES" elif self._match_text_seq("WITH", "UNMATCHED", "ROWS"): text += f" WITH UNMATCHED ROWS" - rows = exp.Var(this=text) + rows = exp.var(text) else: rows = None @@ -2075,7 +2083,7 @@ class Parser(metaclass=_Parser): text += f" TO FIRST {self._advance_any().text}" # type: ignore elif self._match_text_seq("TO", "LAST"): text += f" TO LAST {self._advance_any().text}" # type: ignore - after = exp.Var(this=text) + after = exp.var(text) else: after = None @@ -2093,11 +2101,14 @@ class Parser(metaclass=_Parser): paren += 1 if self._curr.token_type == TokenType.R_PAREN: paren -= 1 + end = self._prev self._advance() + if paren > 0: self.raise_error("Expecting )", self._curr) - pattern = exp.Var(this=self._find_sql(start, end)) + + pattern = exp.var(self._find_sql(start, end)) else: pattern = None @@ -2127,7 +2138,7 @@ class Parser(metaclass=_Parser): alias=self._parse_table_alias(), ) - def _parse_lateral(self) -> t.Optional[exp.Expression]: + def _parse_lateral(self) -> t.Optional[exp.Lateral]: outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY) cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY) @@ -2150,24 +2161,19 @@ class Parser(metaclass=_Parser): expression=self._parse_function() or self._parse_id_var(any_token=False), ) - table_alias: t.Optional[exp.Expression] - if view: table = self._parse_id_var(any_token=False) columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else [] - table_alias = self.expression(exp.TableAlias, this=table, columns=columns) + table_alias: t.Optional[exp.TableAlias] = self.expression( + exp.TableAlias, this=table, columns=columns + ) + elif isinstance(this, exp.Subquery) and this.alias: + # Ensures parity between the Subquery's and the Lateral's "alias" args + table_alias = this.args["alias"].copy() else: table_alias = self._parse_table_alias() - expression = self.expression( - exp.Lateral, - this=this, - view=view, - outer=outer, - alias=table_alias, - ) - - return expression + return self.expression(exp.Lateral, this=this, view=view, outer=outer, alias=table_alias) def _parse_join_parts( self, @@ -2178,7 +2184,7 @@ class Parser(metaclass=_Parser): self._match_set(self.JOIN_KINDS) and self._prev, ) - def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]: + def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]: if self._match(TokenType.COMMA): return self.expression(exp.Join, this=self._parse_table()) @@ -2223,7 +2229,7 @@ class Parser(metaclass=_Parser): def _parse_index( self, index: t.Optional[exp.Expression] = None, - ) -> t.Optional[exp.Expression]: + ) -> t.Optional[exp.Index]: if index: unique = None primary = None @@ -2236,11 +2242,15 @@ class Parser(metaclass=_Parser): unique = self._match(TokenType.UNIQUE) primary = self._match_text_seq("PRIMARY") amp = self._match_text_seq("AMP") + if not self._match(TokenType.INDEX): return None + index = self._parse_id_var() table = None + using = self._parse_field() if self._match(TokenType.USING) else None + if self._match(TokenType.L_PAREN, advance=False): columns = self._parse_wrapped_csv(self._parse_ordered) else: @@ -2250,6 +2260,7 @@ class Parser(metaclass=_Parser): exp.Index, this=index, table=table, + using=using, columns=columns, unique=unique, primary=primary, @@ -2259,7 +2270,7 @@ class Parser(metaclass=_Parser): def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: return ( - (not schema and self._parse_function()) + (not schema and self._parse_function(optional_parens=False)) or self._parse_id_var(any_token=False) or self._parse_string_as_identifier() or self._parse_placeholder() @@ -2314,7 +2325,7 @@ class Parser(metaclass=_Parser): if schema: return self._parse_schema(this=this) - if self.alias_post_tablesample: + if self.ALIAS_POST_TABLESAMPLE: table_sample = self._parse_table_sample() alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS) @@ -2331,7 +2342,7 @@ class Parser(metaclass=_Parser): ) self._match_r_paren() - if not self.alias_post_tablesample: + if not self.ALIAS_POST_TABLESAMPLE: table_sample = self._parse_table_sample() if table_sample: @@ -2340,46 +2351,47 @@ class Parser(metaclass=_Parser): return this - def _parse_unnest(self) -> t.Optional[exp.Expression]: + def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]: if not self._match(TokenType.UNNEST): return None expressions = self._parse_wrapped_csv(self._parse_type) ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY) - alias = self._parse_table_alias() - if alias and self.unnest_column_only: + alias = self._parse_table_alias() if with_alias else None + + if alias and self.UNNEST_COLUMN_ONLY: if alias.args.get("columns"): self.raise_error("Unexpected extra column alias in unnest.") + alias.set("columns", [alias.this]) alias.set("this", None) offset = None if self._match_pair(TokenType.WITH, TokenType.OFFSET): self._match(TokenType.ALIAS) - offset = self._parse_id_var() or exp.Identifier(this="offset") + offset = self._parse_id_var() or exp.to_identifier("offset") return self.expression( - exp.Unnest, - expressions=expressions, - ordinality=ordinality, - alias=alias, - offset=offset, + exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias, offset=offset ) - def _parse_derived_table_values(self) -> t.Optional[exp.Expression]: + def _parse_derived_table_values(self) -> t.Optional[exp.Values]: is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES) if not is_derived and not self._match(TokenType.VALUES): return None expressions = self._parse_csv(self._parse_value) + alias = self._parse_table_alias() if is_derived: self._match_r_paren() - return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias()) + return self.expression( + exp.Values, expressions=expressions, alias=alias or self._parse_table_alias() + ) - def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.Expression]: + def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.TableSample]: if not self._match(TokenType.TABLE_SAMPLE) and not ( as_modifier and self._match_text_seq("USING", "SAMPLE") ): @@ -2456,7 +2468,7 @@ class Parser(metaclass=_Parser): exp.Pivot, this=this, expressions=expressions, using=using, group=group ) - def _parse_pivot(self) -> t.Optional[exp.Expression]: + def _parse_pivot(self) -> t.Optional[exp.Pivot]: index = self._index if self._match(TokenType.PIVOT): @@ -2519,7 +2531,7 @@ class Parser(metaclass=_Parser): def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: return [agg.alias for agg in aggregations] - def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]: + def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Where]: if not skip_where_token and not self._match(TokenType.WHERE): return None @@ -2527,7 +2539,7 @@ class Parser(metaclass=_Parser): exp.Where, comments=self._prev_comments, this=self._parse_conjunction() ) - def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Expression]: + def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Group]: if not skip_group_by_token and not self._match(TokenType.GROUP_BY): return None @@ -2578,12 +2590,12 @@ class Parser(metaclass=_Parser): return self._parse_column() - def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Expression]: + def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Having]: if not skip_having_token and not self._match(TokenType.HAVING): return None return self.expression(exp.Having, this=self._parse_conjunction()) - def _parse_qualify(self) -> t.Optional[exp.Expression]: + def _parse_qualify(self) -> t.Optional[exp.Qualify]: if not self._match(TokenType.QUALIFY): return None return self.expression(exp.Qualify, this=self._parse_conjunction()) @@ -2598,16 +2610,15 @@ class Parser(metaclass=_Parser): exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered) ) - def _parse_sort( - self, exp_class: t.Type[exp.Expression], *texts: str - ) -> t.Optional[exp.Expression]: + def _parse_sort(self, exp_class: t.Type[E], *texts: str) -> t.Optional[E]: if not self._match_text_seq(*texts): return None return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) - def _parse_ordered(self) -> exp.Expression: + def _parse_ordered(self) -> exp.Ordered: this = self._parse_conjunction() self._match(TokenType.ASC) + is_desc = self._match(TokenType.DESC) is_nulls_first = self._match_text_seq("NULLS", "FIRST") is_nulls_last = self._match_text_seq("NULLS", "LAST") @@ -2615,13 +2626,14 @@ class Parser(metaclass=_Parser): asc = not desc nulls_first = is_nulls_first or False explicitly_null_ordered = is_nulls_first or is_nulls_last + if ( not explicitly_null_ordered and ( - (asc and self.null_ordering == "nulls_are_small") - or (desc and self.null_ordering != "nulls_are_small") + (asc and self.NULL_ORDERING == "nulls_are_small") + or (desc and self.NULL_ORDERING != "nulls_are_small") ) - and self.null_ordering != "nulls_are_last" + and self.NULL_ORDERING != "nulls_are_last" ): nulls_first = True @@ -2632,9 +2644,15 @@ class Parser(metaclass=_Parser): ) -> t.Optional[exp.Expression]: if self._match(TokenType.TOP if top else TokenType.LIMIT): limit_paren = self._match(TokenType.L_PAREN) - limit_exp = self.expression( - exp.Limit, this=this, expression=self._parse_number() if top else self._parse_term() - ) + expression = self._parse_number() if top else self._parse_term() + + if self._match(TokenType.COMMA): + offset = expression + expression = self._parse_term() + else: + offset = None + + limit_exp = self.expression(exp.Limit, this=this, expression=expression, offset=offset) if limit_paren: self._match_r_paren() @@ -2667,17 +2685,15 @@ class Parser(metaclass=_Parser): return this def _parse_offset(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: - if not self._match_set((TokenType.OFFSET, TokenType.COMMA)): + if not self._match(TokenType.OFFSET): return this count = self._parse_number() self._match_set((TokenType.ROW, TokenType.ROWS)) return self.expression(exp.Offset, this=this, expression=count) - def _parse_locks(self) -> t.List[exp.Expression]: - # Lists are invariant, so we need to use a type hint here - locks: t.List[exp.Expression] = [] - + def _parse_locks(self) -> t.List[exp.Lock]: + locks = [] while True: if self._match_text_seq("FOR", "UPDATE"): update = True @@ -2768,6 +2784,7 @@ class Parser(metaclass=_Parser): def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: index = self._index - 1 negate = self._match(TokenType.NOT) + if self._match_text_seq("DISTINCT", "FROM"): klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ return self.expression(klass, this=this, expression=self._parse_expression()) @@ -2781,7 +2798,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Not, this=this) if negate else this def _parse_in(self, this: t.Optional[exp.Expression], alias: bool = False) -> exp.In: - unnest = self._parse_unnest() + unnest = self._parse_unnest(with_alias=False) if unnest: this = self.expression(exp.In, this=this, unnest=unnest) elif self._match(TokenType.L_PAREN): @@ -2798,7 +2815,7 @@ class Parser(metaclass=_Parser): return this - def _parse_between(self, this: exp.Expression) -> exp.Expression: + def _parse_between(self, this: exp.Expression) -> exp.Between: low = self._parse_bitwise() self._match(TokenType.AND) high = self._parse_bitwise() @@ -2809,7 +2826,7 @@ class Parser(metaclass=_Parser): return this return self.expression(exp.Escape, this=this, expression=self._parse_string()) - def _parse_interval(self) -> t.Optional[exp.Expression]: + def _parse_interval(self) -> t.Optional[exp.Interval]: if not self._match(TokenType.INTERVAL): return None @@ -2840,9 +2857,7 @@ class Parser(metaclass=_Parser): while True: if self._match_set(self.BITWISE): this = self.expression( - self.BITWISE[self._prev.token_type], - this=this, - expression=self._parse_term(), + self.BITWISE[self._prev.token_type], this=this, expression=self._parse_term() ) elif self._match_pair(TokenType.LT, TokenType.LT): this = self.expression( @@ -2890,7 +2905,7 @@ class Parser(metaclass=_Parser): return this - def _parse_type_size(self) -> t.Optional[exp.Expression]: + def _parse_type_size(self) -> t.Optional[exp.DataTypeSize]: this = self._parse_type() if not this: return None @@ -2926,6 +2941,8 @@ class Parser(metaclass=_Parser): expressions = self._parse_csv( lambda: self._parse_types(check_func=check_func, schema=schema) ) + elif type_token in self.ENUM_TYPE_TOKENS: + expressions = self._parse_csv(self._parse_primary) else: expressions = self._parse_csv(self._parse_type_size) @@ -2943,11 +2960,7 @@ class Parser(metaclass=_Parser): ) while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): - this = exp.DataType( - this=exp.DataType.Type.ARRAY, - expressions=[this], - nested=True, - ) + this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True) return this @@ -2973,23 +2986,14 @@ class Parser(metaclass=_Parser): value: t.Optional[exp.Expression] = None if type_token in self.TIMESTAMPS: - if self._match_text_seq("WITH", "TIME", "ZONE") or type_token == TokenType.TIMESTAMPTZ: + if self._match_text_seq("WITH", "TIME", "ZONE"): + maybe_func = False value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions) - elif ( - self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE") - or type_token == TokenType.TIMESTAMPLTZ - ): + elif self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE"): + maybe_func = False value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions) elif self._match_text_seq("WITHOUT", "TIME", "ZONE"): - if type_token == TokenType.TIME: - value = exp.DataType(this=exp.DataType.Type.TIME, expressions=expressions) - else: - value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) - - maybe_func = maybe_func and value is None - - if value is None: - value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) + maybe_func = False elif type_token == TokenType.INTERVAL: unit = self._parse_var() @@ -3037,7 +3041,7 @@ class Parser(metaclass=_Parser): return self._parse_bracket(this) return self._parse_column_ops(this) - def _parse_column_ops(self, this: exp.Expression) -> exp.Expression: + def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: this = self._parse_bracket(this) while self._match_set(self.COLUMN_OPERATORS): @@ -3057,7 +3061,7 @@ class Parser(metaclass=_Parser): else exp.Literal.string(value) ) else: - field = self._parse_field(anonymous_func=True) + field = self._parse_field(anonymous_func=True, any_token=True) if isinstance(field, exp.Func): # bigquery allows function calls like x.y.count(...) @@ -3089,8 +3093,10 @@ class Parser(metaclass=_Parser): expressions = [primary] while self._match(TokenType.STRING): expressions.append(exp.Literal.string(self._prev.text)) + if len(expressions) > 1: return self.expression(exp.Concat, expressions=expressions) + return primary if self._match_pair(TokenType.DOT, TokenType.NUMBER): @@ -3118,8 +3124,8 @@ class Parser(metaclass=_Parser): if this: this.add_comments(comments) - self._match_r_paren(expression=this) + self._match_r_paren(expression=this) return this return None @@ -3137,18 +3143,21 @@ class Parser(metaclass=_Parser): ) def _parse_function( - self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False + self, + functions: t.Optional[t.Dict[str, t.Callable]] = None, + anonymous: bool = False, + optional_parens: bool = True, ) -> t.Optional[exp.Expression]: if not self._curr: return None token_type = self._curr.token_type - if self._match_set(self.NO_PAREN_FUNCTION_PARSERS): + if optional_parens and self._match_set(self.NO_PAREN_FUNCTION_PARSERS): return self.NO_PAREN_FUNCTION_PARSERS[token_type](self) if not self._next or self._next.token_type != TokenType.L_PAREN: - if token_type in self.NO_PAREN_FUNCTIONS: + if optional_parens and token_type in self.NO_PAREN_FUNCTIONS: self._advance() return self.expression(self.NO_PAREN_FUNCTIONS[token_type]) @@ -3182,8 +3191,7 @@ class Parser(metaclass=_Parser): args = self._parse_csv(lambda: self._parse_lambda(alias=alias)) if function and not anonymous: - this = function(args) - self.validate_expression(this, args) + this = self.validate_expression(function(args), args) else: this = self.expression(exp.Anonymous, this=this, expressions=args) @@ -3210,14 +3218,14 @@ class Parser(metaclass=_Parser): exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True ) - def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]: + def _parse_introducer(self, token: Token) -> exp.Introducer | exp.Identifier: literal = self._parse_primary() if literal: return self.expression(exp.Introducer, this=token.text, expression=literal) return self.expression(exp.Identifier, this=token.text) - def _parse_session_parameter(self) -> exp.Expression: + def _parse_session_parameter(self) -> exp.SessionParameter: kind = None this = self._parse_id_var() or self._parse_primary() @@ -3255,7 +3263,7 @@ class Parser(metaclass=_Parser): if isinstance(this, exp.EQ): left = this.this if isinstance(left, exp.Column): - left.replace(exp.Var(this=left.text("this"))) + left.replace(exp.var(left.text("this"))) return self._parse_limit(self._parse_order(self._parse_respect_or_ignore_nulls(this))) @@ -3279,6 +3287,7 @@ class Parser(metaclass=_Parser): lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(any_token=True)) ) + self._match_r_paren() return self.expression(exp.Schema, this=this, expressions=args) @@ -3286,6 +3295,7 @@ class Parser(metaclass=_Parser): # column defs are not really columns, they're identifiers if isinstance(this, exp.Column): this = this.this + kind = self._parse_types(schema=True) if self._match_text_seq("FOR", "ORDINALITY"): @@ -3303,7 +3313,9 @@ class Parser(metaclass=_Parser): return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints) - def _parse_auto_increment(self) -> exp.Expression: + def _parse_auto_increment( + self, + ) -> exp.GeneratedAsIdentityColumnConstraint | exp.AutoIncrementColumnConstraint: start = None increment = None @@ -3321,7 +3333,7 @@ class Parser(metaclass=_Parser): return exp.AutoIncrementColumnConstraint() - def _parse_compress(self) -> exp.Expression: + def _parse_compress(self) -> exp.CompressColumnConstraint: if self._match(TokenType.L_PAREN, advance=False): return self.expression( exp.CompressColumnConstraint, this=self._parse_wrapped_csv(self._parse_bitwise) @@ -3329,7 +3341,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.CompressColumnConstraint, this=self._parse_bitwise()) - def _parse_generated_as_identity(self) -> exp.Expression: + def _parse_generated_as_identity(self) -> exp.GeneratedAsIdentityColumnConstraint: if self._match_text_seq("BY", "DEFAULT"): on_null = self._match_pair(TokenType.ON, TokenType.NULL) this = self.expression( @@ -3364,11 +3376,13 @@ class Parser(metaclass=_Parser): return this - def _parse_inline(self) -> t.Optional[exp.Expression]: + def _parse_inline(self) -> exp.InlineLengthColumnConstraint: self._match_text_seq("LENGTH") return self.expression(exp.InlineLengthColumnConstraint, this=self._parse_bitwise()) - def _parse_not_constraint(self) -> t.Optional[exp.Expression]: + def _parse_not_constraint( + self, + ) -> t.Optional[exp.NotNullColumnConstraint | exp.CaseSpecificColumnConstraint]: if self._match_text_seq("NULL"): return self.expression(exp.NotNullColumnConstraint) if self._match_text_seq("CASESPECIFIC"): @@ -3417,7 +3431,7 @@ class Parser(metaclass=_Parser): return self.CONSTRAINT_PARSERS[constraint](self) - def _parse_unique(self) -> exp.Expression: + def _parse_unique(self) -> exp.UniqueColumnConstraint: self._match_text_seq("KEY") return self.expression( exp.UniqueColumnConstraint, this=self._parse_schema(self._parse_id_var(any_token=False)) @@ -3460,7 +3474,7 @@ class Parser(metaclass=_Parser): return options - def _parse_references(self, match: bool = True) -> t.Optional[exp.Expression]: + def _parse_references(self, match: bool = True) -> t.Optional[exp.Reference]: if match and not self._match(TokenType.REFERENCES): return None @@ -3473,7 +3487,7 @@ class Parser(metaclass=_Parser): options = self._parse_key_constraint_options() return self.expression(exp.Reference, this=this, expressions=expressions, options=options) - def _parse_foreign_key(self) -> exp.Expression: + def _parse_foreign_key(self) -> exp.ForeignKey: expressions = self._parse_wrapped_id_vars() reference = self._parse_references() options = {} @@ -3501,7 +3515,7 @@ class Parser(metaclass=_Parser): def _parse_primary_key( self, wrapped_optional: bool = False, in_props: bool = False - ) -> exp.Expression: + ) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey: desc = ( self._match_set((TokenType.ASC, TokenType.DESC)) and self._prev.token_type == TokenType.DESC @@ -3514,15 +3528,7 @@ class Parser(metaclass=_Parser): options = self._parse_key_constraint_options() return self.expression(exp.PrimaryKey, expressions=expressions, options=options) - @t.overload - def _parse_bracket(self, this: exp.Expression) -> exp.Expression: - ... - - @t.overload def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - ... - - def _parse_bracket(self, this): if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)): return this @@ -3541,7 +3547,7 @@ class Parser(metaclass=_Parser): elif not this or this.name.upper() == "ARRAY": this = self.expression(exp.Array, expressions=expressions) else: - expressions = apply_index_offset(this, expressions, -self.index_offset) + expressions = apply_index_offset(this, expressions, -self.INDEX_OFFSET) this = self.expression(exp.Bracket, this=this, expressions=expressions) if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET: @@ -3582,8 +3588,7 @@ class Parser(metaclass=_Parser): def _parse_if(self) -> t.Optional[exp.Expression]: if self._match(TokenType.L_PAREN): args = self._parse_csv(self._parse_conjunction) - this = exp.If.from_arg_list(args) - self.validate_expression(this, args) + this = self.validate_expression(exp.If.from_arg_list(args), args) self._match_r_paren() else: index = self._index - 1 @@ -3601,7 +3606,7 @@ class Parser(metaclass=_Parser): return self._parse_window(this) - def _parse_extract(self) -> exp.Expression: + def _parse_extract(self) -> exp.Extract: this = self._parse_function() or self._parse_var() or self._parse_type() if self._match(TokenType.FROM): @@ -3630,9 +3635,37 @@ class Parser(metaclass=_Parser): elif to.this == exp.DataType.Type.CHAR: if self._match(TokenType.CHARACTER_SET): to = self.expression(exp.CharacterSet, this=self._parse_var_or_string()) + elif to.this in exp.DataType.TEMPORAL_TYPES and self._match(TokenType.FORMAT): + fmt = self._parse_string() + + return self.expression( + exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime, + this=this, + format=exp.Literal.string( + format_time( + fmt.this if fmt else "", + self.FORMAT_MAPPING or self.TIME_MAPPING, + self.FORMAT_TRIE or self.TIME_TRIE, + ) + ), + ) return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + def _parse_concat(self) -> t.Optional[exp.Expression]: + args = self._parse_csv(self._parse_conjunction) + if self.CONCAT_NULL_OUTPUTS_STRING: + args = [exp.func("COALESCE", arg, exp.Literal.string("")) for arg in args] + + # Some dialects (e.g. Trino) don't allow a single-argument CONCAT call, so when + # we find such a call we replace it with its argument. + if len(args) == 1: + return args[0] + + return self.expression( + exp.Concat if self.STRICT_STRING_CONCAT else exp.SafeConcat, expressions=args + ) + def _parse_string_agg(self) -> exp.Expression: expression: t.Optional[exp.Expression] @@ -3654,9 +3687,7 @@ class Parser(metaclass=_Parser): # the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them. if not self._match_text_seq("WITHIN", "GROUP"): self._retreat(index) - this = exp.GroupConcat.from_arg_list(args) - self.validate_expression(this, args) - return this + return self.validate_expression(exp.GroupConcat.from_arg_list(args), args) self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller) order = self._parse_order(this=expression) @@ -3679,7 +3710,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) - def _parse_decode(self) -> t.Optional[exp.Expression]: + def _parse_decode(self) -> t.Optional[exp.Decode | exp.Case]: """ There are generally two variants of the DECODE function: @@ -3726,18 +3757,20 @@ class Parser(metaclass=_Parser): return exp.Case(ifs=ifs, default=expressions[-1] if len(expressions) % 2 == 1 else None) - def _parse_json_key_value(self) -> t.Optional[exp.Expression]: + def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]: self._match_text_seq("KEY") key = self._parse_field() self._match(TokenType.COLON) self._match_text_seq("VALUE") value = self._parse_field() + if not key and not value: return None return self.expression(exp.JSONKeyValue, this=key, expression=value) - def _parse_json_object(self) -> exp.Expression: - expressions = self._parse_csv(self._parse_json_key_value) + def _parse_json_object(self) -> exp.JSONObject: + star = self._parse_star() + expressions = [star] if star else self._parse_csv(self._parse_json_key_value) null_handling = None if self._match_text_seq("NULL", "ON", "NULL"): @@ -3767,7 +3800,7 @@ class Parser(metaclass=_Parser): encoding=encoding, ) - def _parse_logarithm(self) -> exp.Expression: + def _parse_logarithm(self) -> exp.Func: # Default argument order is base, expression args = self._parse_csv(self._parse_range) @@ -3780,7 +3813,7 @@ class Parser(metaclass=_Parser): exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0) ) - def _parse_match_against(self) -> exp.Expression: + def _parse_match_against(self) -> exp.MatchAgainst: expressions = self._parse_csv(self._parse_column) self._match_text_seq(")", "AGAINST", "(") @@ -3803,15 +3836,16 @@ class Parser(metaclass=_Parser): ) # https://learn.microsoft.com/en-us/sql/t-sql/functions/openjson-transact-sql?view=sql-server-ver16 - def _parse_open_json(self) -> exp.Expression: + def _parse_open_json(self) -> exp.OpenJSON: this = self._parse_bitwise() path = self._match(TokenType.COMMA) and self._parse_string() - def _parse_open_json_column_def() -> exp.Expression: + def _parse_open_json_column_def() -> exp.OpenJSONColumnDef: this = self._parse_field(any_token=True) kind = self._parse_types() path = self._parse_string() as_json = self._match_pair(TokenType.ALIAS, TokenType.JSON) + return self.expression( exp.OpenJSONColumnDef, this=this, kind=kind, path=path, as_json=as_json ) @@ -3823,7 +3857,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.OpenJSON, this=this, path=path, expressions=expressions) - def _parse_position(self, haystack_first: bool = False) -> exp.Expression: + def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition: args = self._parse_csv(self._parse_bitwise) if self._match(TokenType.IN): @@ -3838,17 +3872,15 @@ class Parser(metaclass=_Parser): needle = seq_get(args, 0) haystack = seq_get(args, 1) - this = exp.StrPosition(this=haystack, substr=needle, position=seq_get(args, 2)) - - self.validate_expression(this, args) - - return this + return self.expression( + exp.StrPosition, this=haystack, substr=needle, position=seq_get(args, 2) + ) - def _parse_join_hint(self, func_name: str) -> exp.Expression: + def _parse_join_hint(self, func_name: str) -> exp.JoinHint: args = self._parse_csv(self._parse_table) return exp.JoinHint(this=func_name.upper(), expressions=args) - def _parse_substring(self) -> exp.Expression: + def _parse_substring(self) -> exp.Substring: # Postgres supports the form: substring(string [from int] [for int]) # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 @@ -3859,12 +3891,9 @@ class Parser(metaclass=_Parser): if self._match(TokenType.FOR): args.append(self._parse_bitwise()) - this = exp.Substring.from_arg_list(args) - self.validate_expression(this, args) - - return this + return self.validate_expression(exp.Substring.from_arg_list(args), args) - def _parse_trim(self) -> exp.Expression: + def _parse_trim(self) -> exp.Trim: # https://www.w3resource.com/sql/character-functions/trim.php # https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html @@ -3885,11 +3914,7 @@ class Parser(metaclass=_Parser): collation = self._parse_bitwise() return self.expression( - exp.Trim, - this=this, - position=position, - expression=expression, - collation=collation, + exp.Trim, this=this, position=position, expression=expression, collation=collation ) def _parse_window_clause(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: @@ -4047,7 +4072,7 @@ class Parser(metaclass=_Parser): return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev) return self._parse_placeholder() - def _parse_string_as_identifier(self) -> t.Optional[exp.Expression]: + def _parse_string_as_identifier(self) -> t.Optional[exp.Identifier]: return exp.to_identifier(self._match(TokenType.STRING) and self._prev.text, quoted=True) def _parse_number(self) -> t.Optional[exp.Expression]: @@ -4097,7 +4122,7 @@ class Parser(metaclass=_Parser): return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev) return None - def _parse_parameter(self) -> exp.Expression: + def _parse_parameter(self) -> exp.Parameter: wrapped = self._match(TokenType.L_BRACE) this = self._parse_var() or self._parse_identifier() or self._parse_primary() self._match(TokenType.R_BRACE) @@ -4183,7 +4208,7 @@ class Parser(metaclass=_Parser): self._parse_set_operations(self._parse_select(nested=True, parse_subquery_alias=False)) ) - def _parse_transaction(self) -> exp.Expression: + def _parse_transaction(self) -> exp.Transaction: this = None if self._match_texts(self.TRANSACTION_KIND): this = self._prev.text @@ -4203,7 +4228,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Transaction, this=this, modes=modes) - def _parse_commit_or_rollback(self) -> exp.Expression: + def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: chain = None savepoint = None is_rollback = self._prev.token_type == TokenType.ROLLBACK @@ -4220,6 +4245,7 @@ class Parser(metaclass=_Parser): if is_rollback: return self.expression(exp.Rollback, savepoint=savepoint) + return self.expression(exp.Commit, chain=chain) def _parse_add_column(self) -> t.Optional[exp.Expression]: @@ -4243,19 +4269,19 @@ class Parser(metaclass=_Parser): return expression - def _parse_drop_column(self) -> t.Optional[exp.Expression]: + def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]: drop = self._match(TokenType.DROP) and self._parse_drop() if drop and not isinstance(drop, exp.Command): drop.set("kind", drop.args.get("kind", "COLUMN")) return drop # https://docs.aws.amazon.com/athena/latest/ug/alter-table-drop-partition.html - def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.Expression: + def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.DropPartition: return self.expression( exp.DropPartition, expressions=self._parse_csv(self._parse_partition), exists=exists ) - def _parse_add_constraint(self) -> t.Optional[exp.Expression]: + def _parse_add_constraint(self) -> exp.AddConstraint: this = None kind = self._prev.token_type @@ -4288,7 +4314,7 @@ class Parser(metaclass=_Parser): self._retreat(index) return self._parse_csv(self._parse_add_column) - def _parse_alter_table_alter(self) -> exp.Expression: + def _parse_alter_table_alter(self) -> exp.AlterColumn: self._match(TokenType.COLUMN) column = self._parse_field(any_token=True) @@ -4316,11 +4342,11 @@ class Parser(metaclass=_Parser): self._retreat(index) return self._parse_csv(self._parse_drop_column) - def _parse_alter_table_rename(self) -> exp.Expression: + def _parse_alter_table_rename(self) -> exp.RenameTable: self._match_text_seq("TO") return self.expression(exp.RenameTable, this=self._parse_table(schema=True)) - def _parse_alter(self) -> t.Optional[exp.Expression]: + def _parse_alter(self) -> exp.AlterTable | exp.Command: start = self._prev if not self._match(TokenType.TABLE): @@ -4345,7 +4371,7 @@ class Parser(metaclass=_Parser): ) return self._parse_as_command(start) - def _parse_merge(self) -> exp.Expression: + def _parse_merge(self) -> exp.Merge: self._match(TokenType.INTO) target = self._parse_table() @@ -4412,7 +4438,7 @@ class Parser(metaclass=_Parser): ) def _parse_show(self) -> t.Optional[exp.Expression]: - parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) # type: ignore + parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE) if parser: return parser(self) self._advance() @@ -4433,17 +4459,9 @@ class Parser(metaclass=_Parser): return None right = self._parse_statement() or self._parse_id_var() - this = self.expression( - exp.EQ, - this=left, - expression=right, - ) + this = self.expression(exp.EQ, this=left, expression=right) - return self.expression( - exp.SetItem, - this=this, - kind=kind, - ) + return self.expression(exp.SetItem, this=this, kind=kind) def _parse_set_transaction(self, global_: bool = False) -> exp.Expression: self._match_text_seq("TRANSACTION") @@ -4458,10 +4476,10 @@ class Parser(metaclass=_Parser): ) def _parse_set_item(self) -> t.Optional[exp.Expression]: - parser = self._find_parser(self.SET_PARSERS, self._set_trie) # type: ignore + parser = self._find_parser(self.SET_PARSERS, self.SET_TRIE) return parser(self) if parser else self._parse_set_item_assignment(kind=None) - def _parse_set(self) -> exp.Expression: + def _parse_set(self) -> exp.Set | exp.Command: index = self._index set_ = self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) @@ -4471,10 +4489,10 @@ class Parser(metaclass=_Parser): return set_ - def _parse_var_from_options(self, options: t.Collection[str]) -> t.Optional[exp.Expression]: + def _parse_var_from_options(self, options: t.Collection[str]) -> t.Optional[exp.Var]: for option in options: if self._match_text_seq(*option.split(" ")): - return exp.Var(this=option) + return exp.var(option) return None def _parse_as_command(self, start: Token) -> exp.Command: -- cgit v1.2.3