From 20739a12c39121a9e7ad3c9a2469ec5a6876199d Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Sat, 3 Jun 2023 01:59:40 +0200
Subject: Merging upstream version 15.0.0.

Signed-off-by: Daniel Baumann
---
 sqlglot/parser.py | 896 ++++++++++++++++++++++++++++++++----------------------
 1 file changed, 525 insertions(+), 371 deletions(-)

(limited to 'sqlglot/parser.py')

diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index d8d9f88..e77bb5a 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -6,22 +6,17 @@ from collections import defaultdict
 
 from sqlglot import exp
 from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
-from sqlglot.helper import (
-    apply_index_offset,
-    count_params,
-    ensure_collection,
-    ensure_list,
-    seq_get,
-)
+from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get
 from sqlglot.tokens import Token, Tokenizer, TokenType
 from sqlglot.trie import in_trie, new_trie
 
-logger = logging.getLogger("sqlglot")
+if t.TYPE_CHECKING:
+    from sqlglot._typing import E
 
-E = t.TypeVar("E", bound=exp.Expression)
+logger = logging.getLogger("sqlglot")
 
 
-def parse_var_map(args: t.Sequence) -> exp.Expression:
+def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
     if len(args) == 1 and args[0].is_star:
         return exp.StarMap(this=args[0])
 
@@ -36,7 +31,7 @@ def parse_var_map(args: t.Sequence) -> exp.Expression:
     )
 
 
-def parse_like(args):
+def parse_like(args: t.List) -> exp.Expression:
     like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0))
     return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like
 
@@ -65,7 +60,7 @@ class Parser(metaclass=_Parser):
 
     Args:
         error_level: the desired error level.
-            Default: ErrorLevel.RAISE
+            Default: ErrorLevel.IMMEDIATE
        error_message_context: determines the amount of context to capture from a
            query string when displaying the error message (in number of characters).
            Default: 50.
@@ -118,8 +113,8 @@ class Parser(metaclass=_Parser):
     NESTED_TYPE_TOKENS = {
         TokenType.ARRAY,
         TokenType.MAP,
-        TokenType.STRUCT,
         TokenType.NULLABLE,
+        TokenType.STRUCT,
     }
 
     TYPE_TOKENS = {
@@ -158,6 +153,7 @@ class Parser(metaclass=_Parser):
         TokenType.TIMESTAMPTZ,
         TokenType.TIMESTAMPLTZ,
         TokenType.DATETIME,
+        TokenType.DATETIME64,
         TokenType.DATE,
         TokenType.DECIMAL,
         TokenType.BIGDECIMAL,
@@ -211,20 +207,18 @@ class Parser(metaclass=_Parser):
         TokenType.VAR,
         TokenType.ANTI,
         TokenType.APPLY,
+        TokenType.ASC,
         TokenType.AUTO_INCREMENT,
         TokenType.BEGIN,
-        TokenType.BOTH,
-        TokenType.BUCKET,
         TokenType.CACHE,
-        TokenType.CASCADE,
         TokenType.COLLATE,
         TokenType.COMMAND,
         TokenType.COMMENT,
         TokenType.COMMIT,
-        TokenType.COMPOUND,
         TokenType.CONSTRAINT,
         TokenType.DEFAULT,
         TokenType.DELETE,
+        TokenType.DESC,
         TokenType.DESCRIBE,
         TokenType.DIV,
         TokenType.END,
@@ -233,7 +227,6 @@ class Parser(metaclass=_Parser):
         TokenType.FALSE,
         TokenType.FIRST,
         TokenType.FILTER,
-        TokenType.FOLLOWING,
         TokenType.FORMAT,
         TokenType.FULL,
         TokenType.IF,
@@ -241,41 +234,31 @@ class Parser(metaclass=_Parser):
         TokenType.ISNULL,
         TokenType.INTERVAL,
         TokenType.KEEP,
-        TokenType.LAZY,
-        TokenType.LEADING,
         TokenType.LEFT,
-        TokenType.LOCAL,
-        TokenType.MATERIALIZED,
+        TokenType.LOAD,
         TokenType.MERGE,
         TokenType.NATURAL,
         TokenType.NEXT,
         TokenType.OFFSET,
-        TokenType.ONLY,
-        TokenType.OPTIONS,
         TokenType.ORDINALITY,
         TokenType.OVERWRITE,
         TokenType.PARTITION,
         TokenType.PERCENT,
         TokenType.PIVOT,
         TokenType.PRAGMA,
-        TokenType.PRECEDING,
         TokenType.RANGE,
         TokenType.REFERENCES,
         TokenType.RIGHT,
         TokenType.ROW,
         TokenType.ROWS,
-        TokenType.SEED,
         TokenType.SEMI,
         TokenType.SET,
+        TokenType.SETTINGS,
         TokenType.SHOW,
-        TokenType.SORTKEY,
         TokenType.TEMPORARY,
         TokenType.TOP,
-        TokenType.TRAILING,
         TokenType.TRUE,
-        TokenType.UNBOUNDED,
         TokenType.UNIQUE,
-        TokenType.UNLOGGED,
         TokenType.UNPIVOT,
         TokenType.VOLATILE,
         TokenType.WINDOW,
@@ -291,6 +274,7 @@ class Parser(metaclass=_Parser):
         TokenType.APPLY,
         TokenType.FULL,
         TokenType.LEFT,
+        TokenType.LOCK,
         TokenType.NATURAL,
         TokenType.OFFSET,
         TokenType.RIGHT,
@@ -301,7 +285,7 @@ class Parser(metaclass=_Parser):
 
     UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
 
-    TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
+    TRIM_TYPES = {"LEADING", "TRAILING", "BOTH"}
 
     FUNC_TOKENS = {
         TokenType.COMMAND,
@@ -322,6 +306,7 @@ class Parser(metaclass=_Parser):
         TokenType.MERGE,
         TokenType.OFFSET,
         TokenType.PRIMARY_KEY,
+        TokenType.RANGE,
         TokenType.REPLACE,
         TokenType.ROW,
         TokenType.UNNEST,
@@ -455,31 +440,31 @@ class Parser(metaclass=_Parser):
     }
 
     EXPRESSION_PARSERS = {
+        exp.Cluster: lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
         exp.Column: lambda self: self._parse_column(),
+        exp.Condition: lambda self: self._parse_conjunction(),
         exp.DataType: lambda self: self._parse_types(),
+        exp.Expression: lambda self: self._parse_statement(),
         exp.From: lambda self: self._parse_from(),
         exp.Group: lambda self: self._parse_group(),
+        exp.Having: lambda self: self._parse_having(),
         exp.Identifier: lambda self: self._parse_id_var(),
-        exp.Lateral: lambda self: self._parse_lateral(),
         exp.Join: lambda self: self._parse_join(),
-        exp.Order: lambda self: self._parse_order(),
-        exp.Cluster: lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
-        exp.Sort: lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
         exp.Lambda: lambda self: self._parse_lambda(),
+        exp.Lateral: lambda self: self._parse_lateral(),
         exp.Limit: lambda self: self._parse_limit(),
         exp.Offset: lambda self: self._parse_offset(),
-        exp.TableAlias: lambda self: self._parse_table_alias(),
-        exp.Table: lambda self: self._parse_table(),
-        exp.Condition: lambda self: self._parse_conjunction(),
-        exp.Expression: lambda self: self._parse_statement(),
-        exp.Properties: lambda self: self._parse_properties(),
-        exp.Where: lambda self: self._parse_where(),
+        exp.Order: lambda self: self._parse_order(),
         exp.Ordered: lambda self: self._parse_ordered(),
-        exp.Having: lambda self: self._parse_having(),
-        exp.With: lambda self: self._parse_with(),
-        exp.Window: lambda self: self._parse_named_window(),
+        exp.Properties: lambda self: self._parse_properties(),
         exp.Qualify: lambda self: self._parse_qualify(),
         exp.Returning: lambda self: self._parse_returning(),
+        exp.Sort: lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
+        exp.Table: lambda self: self._parse_table_parts(),
+        exp.TableAlias: lambda self: self._parse_table_alias(),
+        exp.Where: lambda self: self._parse_where(),
+        exp.Window: lambda self: self._parse_named_window(),
+        exp.With: lambda self: self._parse_with(),
         "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
     }
 
@@ -495,9 +480,13 @@ class Parser(metaclass=_Parser):
         TokenType.DESCRIBE: lambda self: self._parse_describe(),
         TokenType.DROP: lambda self: self._parse_drop(),
         TokenType.END: lambda self: self._parse_commit_or_rollback(),
+        TokenType.FROM: lambda self: exp.select("*").from_(
+            t.cast(exp.From, self._parse_from(skip_from_token=True))
+        ),
         TokenType.INSERT: lambda self: self._parse_insert(),
-        TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
+        TokenType.LOAD: lambda self: self._parse_load(),
         TokenType.MERGE: lambda self: self._parse_merge(),
+        TokenType.PIVOT: lambda self: self._parse_simplified_pivot(),
         TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()),
         TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
         TokenType.SET: lambda self: self._parse_set(),
@@ -536,7 +525,10 @@ class Parser(metaclass=_Parser):
         TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
         TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
         TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
-        TokenType.NATIONAL: lambda self, token: self._parse_national(token),
+        TokenType.NATIONAL_STRING: lambda self, token: self.expression(
+            exp.National, this=token.text
+        ),
+        TokenType.RAW_STRING: lambda self, token: self.expression(exp.RawString, this=token.text),
         TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
     }
 
@@ -551,91 +543,76 @@ class Parser(metaclass=_Parser):
 
     RANGE_PARSERS = {
         TokenType.BETWEEN: lambda self, this: self._parse_between(this),
         TokenType.GLOB: binary_range_parser(exp.Glob),
-        TokenType.OVERLAPS: binary_range_parser(exp.Overlaps),
+        TokenType.ILIKE: binary_range_parser(exp.ILike),
         TokenType.IN: lambda self, this: self._parse_in(this),
+        TokenType.IRLIKE: binary_range_parser(exp.RegexpILike),
         TokenType.IS: lambda self, this: self._parse_is(this),
         TokenType.LIKE: binary_range_parser(exp.Like),
-        TokenType.ILIKE: binary_range_parser(exp.ILike),
-        TokenType.IRLIKE: binary_range_parser(exp.RegexpILike),
+        TokenType.OVERLAPS: binary_range_parser(exp.Overlaps),
         TokenType.RLIKE: binary_range_parser(exp.RegexpLike),
         TokenType.SIMILAR_TO: binary_range_parser(exp.SimilarTo),
     }
 
-    PROPERTY_PARSERS = {
-        "AFTER": lambda self: self._parse_afterjournal(
-            no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
-        ),
+    PROPERTY_PARSERS: t.Dict[str, t.Callable] = {
         "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
         "AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
-        "BEFORE": lambda self: self._parse_journal(
-            no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
-        ),
         "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
         "CHARACTER SET": lambda self: self._parse_character_set(),
         "CHECKSUM": lambda self: self._parse_checksum(),
-        "CLUSTER BY": lambda self: self.expression(
-            exp.Cluster, expressions=self._parse_csv(self._parse_ordered)
-        ),
+        "CLUSTER": lambda self: self._parse_cluster(),
         "COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
         "COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
-        "DATABLOCKSIZE": lambda self: self._parse_datablocksize(
-            default=self._prev.text.upper() == "DEFAULT"
-        ),
+        "DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs),
         "DEFINER": lambda self: self._parse_definer(),
         "DETERMINISTIC": lambda self: self.expression(
             exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
         ),
         "DISTKEY": lambda self: self._parse_distkey(),
         "DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
+        "ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty),
         "EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
         "EXTERNAL": lambda self: self.expression(exp.ExternalProperty),
-        "FALLBACK": lambda self: self._parse_fallback(no=self._prev.text.upper() == "NO"),
+        "FALLBACK": lambda self, **kwargs: self._parse_fallback(**kwargs),
         "FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
         "FREESPACE": lambda self: self._parse_freespace(),
-        "GLOBAL": lambda self: self._parse_temporary(global_=True),
         "IMMUTABLE": lambda self: self.expression(
             exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
         ),
-        "JOURNAL": lambda self: self._parse_journal(
-            no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
-        ),
+        "JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs),
         "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
         "LIKE": lambda self: self._parse_create_like(),
-        "LOCAL": lambda self: self._parse_afterjournal(no=False, dual=False, local=True),
         "LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty),
         "LOCK": lambda self: self._parse_locking(),
         "LOCKING": lambda self: self._parse_locking(),
-        "LOG": lambda self: self._parse_log(no=self._prev.text.upper() == "NO"),
+        "LOG": lambda self, **kwargs: self._parse_log(**kwargs),
         "MATERIALIZED": lambda self: self.expression(exp.MaterializedProperty),
-        "MAX": lambda self: self._parse_datablocksize(),
-        "MAXIMUM": lambda self: self._parse_datablocksize(),
-        "MERGEBLOCKRATIO": lambda self: self._parse_mergeblockratio(
-            no=self._prev.text.upper() == "NO", default=self._prev.text.upper() == "DEFAULT"
-        ),
-        "MIN": lambda self: self._parse_datablocksize(),
-        "MINIMUM": lambda self: self._parse_datablocksize(),
+        "MERGEBLOCKRATIO": lambda self, **kwargs: self._parse_mergeblockratio(**kwargs),
         "MULTISET": lambda self: self.expression(exp.SetProperty, multi=True),
-        "NO": lambda self: self._parse_noprimaryindex(),
-        "NOT": lambda self: self._parse_afterjournal(no=False, dual=False, local=False),
-        "ON": lambda self: self._parse_oncommit(),
+        "NO": lambda self: self._parse_no_property(),
+        "ON": lambda self: self._parse_on_property(),
+        "ORDER BY": lambda self: self._parse_order(skip_order_token=True),
         "PARTITION BY": lambda self: self._parse_partitioned_by(),
         "PARTITIONED BY": lambda self: self._parse_partitioned_by(),
         "PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
+        "PRIMARY KEY": lambda self: self._parse_primary_key(),
         "RETURNS": lambda self: self._parse_returns(),
         "ROW": lambda self: self._parse_row(),
         "ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
         "SET": lambda self: self.expression(exp.SetProperty, multi=False),
+        "SETTINGS": lambda self: self.expression(
+            exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item)
+        ),
         "SORTKEY": lambda self: self._parse_sortkey(),
         "STABLE": lambda self: self.expression(
             exp.StabilityProperty, this=exp.Literal.string("STABLE")
         ),
         "STORED": lambda self: self._parse_stored(),
-        "TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
         "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
-        "TEMP": lambda self: self._parse_temporary(global_=False),
-        "TEMPORARY": lambda self: self._parse_temporary(global_=False),
+        "TEMP": lambda self: self.expression(exp.TemporaryProperty),
+        "TEMPORARY": lambda self: self.expression(exp.TemporaryProperty),
         "TRANSIENT": lambda self: self.expression(exp.TransientProperty),
-        "USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
+        "TTL": lambda self: self._parse_ttl(),
+        "USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
         "VOLATILE": lambda self: self._parse_volatile_property(),
         "WITH": lambda self: self._parse_with_property(),
     }
@@ -679,6 +656,7 @@ class Parser(metaclass=_Parser):
         "TITLE": lambda self: self.expression(
             exp.TitleColumnConstraint, this=self._parse_var_or_string()
         ),
+        "TTL": lambda self: self.expression(exp.MergeTreeTTL, expressions=[self._parse_bitwise()]),
         "UNIQUE": lambda self: self._parse_unique(),
         "UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint),
     }
@@ -704,6 +682,8 @@ class Parser(metaclass=_Parser):
         ),
     }
 
+    FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}
+
     FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
         "CAST": lambda self: self._parse_cast(self.STRICT_CAST),
         "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
@@ -712,7 +692,9 @@ class Parser(metaclass=_Parser):
         "JSON_OBJECT": lambda self: self._parse_json_object(),
         "LOG": lambda self: self._parse_logarithm(),
         "MATCH": lambda self: self._parse_match_against(),
+        "OPENJSON": lambda self: self._parse_open_json(),
         "POSITION": lambda self: self._parse_position(),
+        "SAFE_CAST": lambda self: self._parse_cast(False),
         "STRING_AGG": lambda self: self._parse_string_agg(),
         "SUBSTRING": lambda self: self._parse_substring(),
         "TRIM": lambda self: self._parse_trim(),
@@ -721,19 +703,18 @@ class Parser(metaclass=_Parser):
     }
 
     QUERY_MODIFIER_PARSERS = {
+        "joins": lambda self: list(iter(self._parse_join, None)),
+        "laterals": lambda self: list(iter(self._parse_lateral, None)),
         "match": lambda self: self._parse_match_recognize(),
         "where": lambda self: self._parse_where(),
         "group": lambda self: self._parse_group(),
         "having": lambda self: self._parse_having(),
         "qualify": lambda self: self._parse_qualify(),
         "windows": lambda self: self._parse_window_clause(),
-        "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
-        "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
-        "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
         "order": lambda self: self._parse_order(),
         "limit": lambda self: self._parse_limit(),
         "offset": lambda self: self._parse_offset(),
-        "lock": lambda self: self._parse_lock(),
+        "locks": lambda self: self._parse_locks(),
         "sample": lambda self: self._parse_table_sample(as_modifier=True),
     }
 
@@ -763,8 +744,11 @@ class Parser(metaclass=_Parser):
 
     INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
 
+    CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"}
+
     WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
     WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
+    WINDOW_SIDES = {"FOLLOWING", "PRECEDING"}
 
     ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
 
@@ -772,8 +756,8 @@ class Parser(metaclass=_Parser):
 
     CONVERT_TYPE_FIRST = False
 
-    QUOTED_PIVOT_COLUMNS: t.Optional[bool] = None
     PREFIXED_PIVOT_COLUMNS = False
+    IDENTIFY_PIVOT_STRINGS = False
 
     LOG_BASE_FIRST = True
     LOG_DEFAULTS_TO_LN = False
@@ -875,7 +859,7 @@ class Parser(metaclass=_Parser):
                 e.errors[0]["into_expression"] = expression_type
             errors.append(e)
         raise ParseError(
-            f"Failed to parse into {expression_types}",
+            f"Failed to parse '{sql or raw_tokens}' into {expression_types}",
             errors=merge_errors(errors),
         ) from errors[-1]
@@ -933,7 +917,7 @@ class Parser(metaclass=_Parser):
         """
         token = token or self._curr or self._prev or Token.string("")
         start = token.start
-        end = token.end
+        end = token.end + 1
         start_context = self.sql[max(start - self.error_message_context, 0) : start]
         highlight = self.sql[start:end]
         end_context = self.sql[end : end + self.error_message_context]
@@ -996,7 +980,7 @@ class Parser(metaclass=_Parser):
             self.raise_error(error_message)
 
     def _find_sql(self, start: Token, end: Token) -> str:
-        return self.sql[start.start : end.end]
+        return self.sql[start.start : end.end + 1]
 
     def _advance(self, times: int = 1) -> None:
         self._index += times
@@ -1042,6 +1026,44 @@ class Parser(metaclass=_Parser):
             exp.Comment, this=this, kind=kind.text, expression=self._parse_string(), exists=exists
         )
 
+    # https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl
+    def _parse_ttl(self) -> exp.Expression:
+        def _parse_ttl_action() -> t.Optional[exp.Expression]:
+            this = self._parse_bitwise()
+
+            if self._match_text_seq("DELETE"):
+                return self.expression(exp.MergeTreeTTLAction, this=this, delete=True)
+            if self._match_text_seq("RECOMPRESS"):
+                return self.expression(
+                    exp.MergeTreeTTLAction, this=this, recompress=self._parse_bitwise()
+                )
+            if self._match_text_seq("TO", "DISK"):
+                return self.expression(
+                    exp.MergeTreeTTLAction, this=this, to_disk=self._parse_string()
+                )
+            if self._match_text_seq("TO", "VOLUME"):
+                return self.expression(
+                    exp.MergeTreeTTLAction, this=this, to_volume=self._parse_string()
+                )
+
+            return this
+
+        expressions = self._parse_csv(_parse_ttl_action)
+        where = self._parse_where()
+        group = self._parse_group()
+
+        aggregates = None
+        if group and self._match(TokenType.SET):
+            aggregates = self._parse_csv(self._parse_set_item)
+
+        return self.expression(
+            exp.MergeTreeTTL,
+            expressions=expressions,
+            where=where,
+            group=group,
+            aggregates=aggregates,
+        )
+
     def _parse_statement(self) -> t.Optional[exp.Expression]:
         if self._curr is None:
             return None
@@ -1054,14 +1076,12 @@ class Parser(metaclass=_Parser):
 
         expression = self._parse_expression()
         expression = self._parse_set_operations(expression) if expression else self._parse_select()
-
-        self._parse_query_modifiers(expression)
-        return expression
+        return self._parse_query_modifiers(expression)
 
     def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]:
         start = self._prev
         temporary = self._match(TokenType.TEMPORARY)
-        materialized = self._match(TokenType.MATERIALIZED)
+        materialized = self._match_text_seq("MATERIALIZED")
         kind = self._match_set(self.CREATABLES) and self._prev.text
         if not kind:
             return self._parse_as_command(start)
@@ -1073,7 +1093,7 @@ class Parser(metaclass=_Parser):
             kind=kind,
             temporary=temporary,
             materialized=materialized,
-            cascade=self._match(TokenType.CASCADE),
+            cascade=self._match_text_seq("CASCADE"),
             constraints=self._match_text_seq("CONSTRAINTS"),
             purge=self._match_text_seq("PURGE"),
         )
@@ -1111,6 +1131,7 @@ class Parser(metaclass=_Parser):
         indexes = None
         no_schema_binding = None
         begin = None
+        clone = None
 
         if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
             this = self._parse_user_defined_function(kind=create_token.token_type)
@@ -1128,7 +1149,7 @@ class Parser(metaclass=_Parser):
             if return_:
                 expression = self.expression(exp.Return, this=expression)
         elif create_token.token_type == TokenType.INDEX:
-            this = self._parse_index()
+            this = self._parse_index(index=self._parse_id_var())
         elif create_token.token_type in self.DB_CREATABLES:
             table_parts = self._parse_table_parts(schema=True)
 
@@ -1166,33 +1187,40 @@ class Parser(metaclass=_Parser):
                 expression = self._parse_ddl_select()
 
                 if create_token.token_type == TokenType.TABLE:
-                    # exp.Properties.Location.POST_EXPRESSION
-                    temp_properties = self._parse_properties()
-                    if properties and temp_properties:
-                        properties.expressions.extend(temp_properties.expressions)
-                    elif temp_properties:
-                        properties = temp_properties
-
                     indexes = []
                     while True:
-                        index = self._parse_create_table_index()
+                        index = self._parse_index()
 
-                        # exp.Properties.Location.POST_INDEX
-                        if self._match(TokenType.PARTITION_BY, advance=False):
-                            temp_properties = self._parse_properties()
-                            if properties and temp_properties:
-                                properties.expressions.extend(temp_properties.expressions)
-                            elif temp_properties:
-                                properties = temp_properties
+                        # exp.Properties.Location.POST_EXPRESSION or exp.Properties.Location.POST_INDEX
+                        temp_properties = self._parse_properties()
+                        if properties and temp_properties:
+                            properties.expressions.extend(temp_properties.expressions)
+                        elif temp_properties:
+                            properties = temp_properties
 
                         if not index:
                             break
                         else:
+                            self._match(TokenType.COMMA)
                             indexes.append(index)
                 elif create_token.token_type == TokenType.VIEW:
                     if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"):
                         no_schema_binding = True
 
+        if self._match_text_seq("CLONE"):
+            clone = self._parse_table(schema=True)
+            when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper()
+            clone_kind = (
+                self._match(TokenType.L_PAREN)
+                and self._match_texts(self.CLONE_KINDS)
+                and self._prev.text.upper()
+            )
+            clone_expression = self._match(TokenType.FARROW) and self._parse_bitwise()
+            self._match(TokenType.R_PAREN)
+            clone = self.expression(
+                exp.Clone, this=clone, when=when, kind=clone_kind, expression=clone_expression
+            )
+
         return self.expression(
             exp.Create,
             this=this,
@@ -1205,18 +1233,31 @@ class Parser(metaclass=_Parser):
             indexes=indexes,
             no_schema_binding=no_schema_binding,
             begin=begin,
+            clone=clone,
         )
 
     def _parse_property_before(self) -> t.Optional[exp.Expression]:
+        # only used for teradata currently
         self._match(TokenType.COMMA)
-        # parsers look to _prev for no/dual/default, so need to consume first
-        self._match_text_seq("NO")
-        self._match_text_seq("DUAL")
-        self._match_text_seq("DEFAULT")
 
+        kwargs = {
+            "no": self._match_text_seq("NO"),
+            "dual": self._match_text_seq("DUAL"),
+            "before": self._match_text_seq("BEFORE"),
+            "default": self._match_text_seq("DEFAULT"),
+            "local": (self._match_text_seq("LOCAL") and "LOCAL")
+            or (self._match_text_seq("NOT", "LOCAL") and "NOT LOCAL"),
+            "after": self._match_text_seq("AFTER"),
+            "minimum": self._match_texts(("MIN", "MINIMUM")),
+            "maximum": self._match_texts(("MAX", "MAXIMUM")),
+        }
 
-        if self.PROPERTY_PARSERS.get(self._curr.text.upper()):
-            return self.PROPERTY_PARSERS[self._curr.text.upper()](self)
+        if self._match_texts(self.PROPERTY_PARSERS):
+            parser = self.PROPERTY_PARSERS[self._prev.text.upper()]
+            try:
+                return parser(self, **{k: v for k, v in kwargs.items() if v})
+            except TypeError:
+                self.raise_error(f"Cannot parse property '{self._prev.text}'")
 
         return None
 
@@ -1227,7 +1268,7 @@ class Parser(metaclass=_Parser):
         if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
             return self._parse_character_set(default=True)
 
-        if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
+        if self._match_text_seq("COMPOUND", "SORTKEY"):
             return self._parse_sortkey(compound=True)
 
         if self._match_text_seq("SQL", "SECURITY"):
@@ -1262,23 +1303,20 @@ class Parser(metaclass=_Parser):
     def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression:
         self._match(TokenType.EQ)
         self._match(TokenType.ALIAS)
-        return self.expression(
-            exp_class,
-            this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
-        )
+        return self.expression(exp_class, this=self._parse_field())
 
-    def _parse_properties(self, before=None) -> t.Optional[exp.Expression]:
+    def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Expression]:
         properties = []
 
         while True:
             if before:
-                identified_property = self._parse_property_before()
+                prop = self._parse_property_before()
             else:
-                identified_property = self._parse_property()
+                prop = self._parse_property()
 
-            if not identified_property:
+            if not prop:
                 break
-            for p in ensure_list(identified_property):
+            for p in ensure_list(prop):
                 properties.append(p)
 
         if properties:
@@ -1286,8 +1324,7 @@ class Parser(metaclass=_Parser):
 
         return None
 
-    def _parse_fallback(self, no=False) -> exp.Expression:
-        self._match_text_seq("FALLBACK")
+    def _parse_fallback(self, no: bool = False) -> exp.Expression:
         return self.expression(
             exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION")
         )
@@ -1345,23 +1382,13 @@ class Parser(metaclass=_Parser):
         self._match(TokenType.EQ)
         return self.expression(exp.WithJournalTableProperty, this=self._parse_table_parts())
 
-    def _parse_log(self, no=False) -> exp.Expression:
-        self._match_text_seq("LOG")
+    def _parse_log(self, no: bool = False) -> exp.Expression:
         return self.expression(exp.LogProperty, no=no)
 
-    def _parse_journal(self, no=False, dual=False) -> exp.Expression:
-        before = self._match_text_seq("BEFORE")
-        self._match_text_seq("JOURNAL")
-        return self.expression(exp.JournalProperty, no=no, dual=dual, before=before)
-
-    def _parse_afterjournal(self, no=False, dual=False, local=None) -> exp.Expression:
-        self._match_text_seq("NOT")
-        self._match_text_seq("LOCAL")
-        self._match_text_seq("AFTER", "JOURNAL")
-        return self.expression(exp.AfterJournalProperty, no=no, dual=dual, local=local)
+    def _parse_journal(self, **kwargs) -> exp.Expression:
+        return self.expression(exp.JournalProperty, **kwargs)
 
     def _parse_checksum(self) -> exp.Expression:
-        self._match_text_seq("CHECKSUM")
         self._match(TokenType.EQ)
 
         on = None
@@ -1377,49 +1404,55 @@ class Parser(metaclass=_Parser):
             default=default,
         )
 
+    def _parse_cluster(self) -> t.Optional[exp.Expression]:
+        if not self._match_text_seq("BY"):
+            self._retreat(self._index - 1)
+            return None
+        return self.expression(
+            exp.Cluster,
+            expressions=self._parse_csv(self._parse_ordered),
+        )
+
     def _parse_freespace(self) -> exp.Expression:
-        self._match_text_seq("FREESPACE")
         self._match(TokenType.EQ)
         return self.expression(
             exp.FreespaceProperty, this=self._parse_number(), percent=self._match(TokenType.PERCENT)
         )
 
-    def _parse_mergeblockratio(self, no=False, default=False) -> exp.Expression:
-        self._match_text_seq("MERGEBLOCKRATIO")
+    def _parse_mergeblockratio(self, no: bool = False, default: bool = False) -> exp.Expression:
         if self._match(TokenType.EQ):
             return self.expression(
                 exp.MergeBlockRatioProperty,
                 this=self._parse_number(),
                 percent=self._match(TokenType.PERCENT),
             )
-        else:
-            return self.expression(
-                exp.MergeBlockRatioProperty,
-                no=no,
-                default=default,
-            )
+        return self.expression(
+            exp.MergeBlockRatioProperty,
+            no=no,
+            default=default,
+        )
 
-    def _parse_datablocksize(self, default=None) -> exp.Expression:
-        if default:
-            self._match_text_seq("DATABLOCKSIZE")
-            return self.expression(exp.DataBlocksizeProperty, default=True)
-        elif self._match_texts(("MIN", "MINIMUM")):
-            self._match_text_seq("DATABLOCKSIZE")
-            return self.expression(exp.DataBlocksizeProperty, min=True)
-        elif self._match_texts(("MAX", "MAXIMUM")):
-            self._match_text_seq("DATABLOCKSIZE")
-            return self.expression(exp.DataBlocksizeProperty, min=False)
-
-        self._match_text_seq("DATABLOCKSIZE")
+    def _parse_datablocksize(
+        self,
+        default: t.Optional[bool] = None,
+        minimum: t.Optional[bool] = None,
+        maximum: t.Optional[bool] = None,
+    ) -> exp.Expression:
         self._match(TokenType.EQ)
         size = self._parse_number()
 
         units = None
         if self._match_texts(("BYTES", "KBYTES", "KILOBYTES")):
             units = self._prev.text
-        return self.expression(exp.DataBlocksizeProperty, size=size, units=units)
+        return self.expression(
+            exp.DataBlocksizeProperty,
+            size=size,
+            units=units,
+            default=default,
+            minimum=minimum,
+            maximum=maximum,
+        )
 
     def _parse_blockcompression(self) -> exp.Expression:
-        self._match_text_seq("BLOCKCOMPRESSION")
         self._match(TokenType.EQ)
         always = self._match_text_seq("ALWAYS")
         manual = self._match_text_seq("MANUAL")
@@ -1516,7 +1549,7 @@ class Parser(metaclass=_Parser):
             this=self._parse_schema() or self._parse_bracket(self._parse_field()),
         )
 
-    def _parse_withdata(self, no=False) -> exp.Expression:
+    def _parse_withdata(self, no: bool = False) -> exp.Expression:
         if self._match_text_seq("AND", "STATISTICS"):
             statistics = True
         elif self._match_text_seq("AND", "NO", "STATISTICS"):
@@ -1526,13 +1559,17 @@ class Parser(metaclass=_Parser):
 
         return self.expression(exp.WithDataProperty, no=no, statistics=statistics)
 
-    def _parse_noprimaryindex(self) -> exp.Expression:
-        self._match_text_seq("PRIMARY", "INDEX")
-        return exp.NoPrimaryIndexProperty()
+    def _parse_no_property(self) -> t.Optional[exp.Property]:
+        if self._match_text_seq("PRIMARY", "INDEX"):
+            return exp.NoPrimaryIndexProperty()
+        return None
 
-    def _parse_oncommit(self) -> exp.Expression:
-        self._match_text_seq("COMMIT", "PRESERVE", "ROWS")
-        return exp.OnCommitProperty()
+    def _parse_on_property(self) -> t.Optional[exp.Property]:
+        if self._match_text_seq("COMMIT", "PRESERVE", "ROWS"):
+            return exp.OnCommitProperty()
+        elif self._match_text_seq("COMMIT", "DELETE", "ROWS"):
+            return exp.OnCommitProperty(delete=True)
+        return None
 
     def _parse_distkey(self) -> exp.Expression:
         return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
 
@@ -1587,10 +1624,6 @@ class Parser(metaclass=_Parser):
 
         return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
 
-    def _parse_temporary(self, global_=False) -> exp.Expression:
-        self._match(TokenType.TEMPORARY)  # in case calling from "GLOBAL"
-        return self.expression(exp.TemporaryProperty, global_=global_)
-
     def _parse_describe(self) -> exp.Expression:
         kind = self._match_set(self.CREATABLES) and self._prev.text
         this = self._parse_table()
@@ -1599,7 +1632,7 @@ class Parser(metaclass=_Parser):
 
     def _parse_insert(self) -> exp.Expression:
         overwrite = self._match(TokenType.OVERWRITE)
-        local = self._match(TokenType.LOCAL)
+        local = self._match_text_seq("LOCAL")
         alternative = None
 
         if self._match_text_seq("DIRECTORY"):
@@ -1700,23 +1733,25 @@ class Parser(metaclass=_Parser):
 
         return self.expression(exp.RowFormatDelimitedProperty, **kwargs)  # type: ignore
 
-    def _parse_load_data(self) -> exp.Expression:
-        local = self._match(TokenType.LOCAL)
-        self._match_text_seq("INPATH")
-        inpath = self._parse_string()
-        overwrite = self._match(TokenType.OVERWRITE)
-        self._match_pair(TokenType.INTO, TokenType.TABLE)
+    def _parse_load(self) -> exp.Expression:
+        if self._match_text_seq("DATA"):
+            local = self._match_text_seq("LOCAL")
+            self._match_text_seq("INPATH")
+            inpath = self._parse_string()
+            overwrite = self._match(TokenType.OVERWRITE)
+            self._match_pair(TokenType.INTO, TokenType.TABLE)
 
-        return self.expression(
-            exp.LoadData,
-            this=self._parse_table(schema=True),
-            local=local,
-            overwrite=overwrite,
-            inpath=inpath,
-            partition=self._parse_partition(),
-            input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(),
-            serde=self._match_text_seq("SERDE") and self._parse_string(),
-        )
+            return self.expression(
+                exp.LoadData,
+                this=self._parse_table(schema=True),
+                local=local,
+                overwrite=overwrite,
+                inpath=inpath,
+                partition=self._parse_partition(),
+                input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(),
+                serde=self._match_text_seq("SERDE") and self._parse_string(),
+            )
+        return self._parse_as_command(self._prev)
 
     def _parse_delete(self) -> exp.Expression:
         self._match(TokenType.FROM)
@@ -1735,7 +1770,7 @@ class Parser(metaclass=_Parser):
             **{  # type: ignore
                 "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
                 "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
-                "from": self._parse_from(),
+                "from": self._parse_from(modifiers=True),
                 "where": self._parse_where(),
                 "returning": self._parse_returning(),
             },
@@ -1752,12 +1787,12 @@ class Parser(metaclass=_Parser):
         )
 
     def _parse_cache(self) -> exp.Expression:
-        lazy = self._match(TokenType.LAZY)
+        lazy = self._match_text_seq("LAZY")
         self._match(TokenType.TABLE)
         table = self._parse_table(schema=True)
 
         options = []
-        if self._match(TokenType.OPTIONS):
+        if self._match_text_seq("OPTIONS"):
             self._match_l_paren()
             k = self._parse_string()
             self._match(TokenType.EQ)
@@ -1851,11 +1886,10 @@ class Parser(metaclass=_Parser):
             if from_:
                 this.set("from", from_)
 
-            self._parse_query_modifiers(this)
+            this = self._parse_query_modifiers(this)
         elif (table or nested) and self._match(TokenType.L_PAREN):
             this = self._parse_table() if table else self._parse_select(nested=True)
-            self._parse_query_modifiers(this)
-            this = self._parse_set_operations(this)
+            this = self._parse_set_operations(self._parse_query_modifiers(this))
             self._match_r_paren()
 
             # early return so that subquery unions aren't parsed again
@@ -1868,6 +1902,10 @@ class Parser(metaclass=_Parser):
                 expressions=self._parse_csv(self._parse_value),
                 alias=self._parse_table_alias(),
             )
+        elif self._match(TokenType.PIVOT):
+            this = self._parse_simplified_pivot()
+        elif self._match(TokenType.FROM):
+            this = exp.select("*").from_(t.cast(exp.From, self._parse_from(skip_from_token=True)))
         else:
             this = None
 
@@ -1929,7 +1967,9 @@ class Parser(metaclass=_Parser):
 
     def _parse_subquery(
         self, this: t.Optional[exp.Expression], parse_alias: bool = True
-    ) -> exp.Expression:
+    ) -> t.Optional[exp.Expression]:
+        if not this:
+            return None
         return self.expression(
             exp.Subquery,
             this=this,
@@ -1937,35 +1977,16 @@ class Parser(metaclass=_Parser):
             alias=self._parse_table_alias() if parse_alias else None,
         )
 
-    def _parse_query_modifiers(self, this: t.Optional[exp.Expression]) -> None:
-        if not isinstance(this, self.MODIFIABLES):
-            return
-
-        table = isinstance(this, exp.Table)
-
-        while True:
-            join = self._parse_join()
-            if join:
-                this.append("joins", join)
-
-            lateral = None
-            if not join:
-                lateral = self._parse_lateral()
-                if lateral:
-                    this.append("laterals", lateral)
-
-            comma = None if table else self._match(TokenType.COMMA)
-            if comma:
-                this.args["from"].append("expressions", self._parse_table())
-
-            if not (lateral or join or comma):
-                break
-
-        for key, parser in self.QUERY_MODIFIER_PARSERS.items():
-            expression = parser(self)
+    def _parse_query_modifiers(
+        self, this: t.Optional[exp.Expression]
+    ) -> t.Optional[exp.Expression]:
+        if isinstance(this, self.MODIFIABLES):
+            for key, parser in self.QUERY_MODIFIER_PARSERS.items():
+                expression = parser(self)
 
-            if expression:
-                this.set(key, expression)
+                if expression:
+                    this.set(key, expression)
+
+        return this
 
     def _parse_hint(self) -> t.Optional[exp.Expression]:
         if self._match(TokenType.HINT):
@@ -1981,19 +2002,26 @@ class Parser(metaclass=_Parser):
             return None
 
         temp = self._match(TokenType.TEMPORARY)
-        unlogged = self._match(TokenType.UNLOGGED)
+        unlogged = self._match_text_seq("UNLOGGED")
         self._match(TokenType.TABLE)
 
         return self.expression(
             exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged
         )
 
-    def _parse_from(self) -> t.Optional[exp.Expression]:
-        if not self._match(TokenType.FROM):
+    def _parse_from(
+        self, modifiers: bool = False, skip_from_token: bool = False
+    ) -> t.Optional[exp.From]:
+        if not skip_from_token and not self._match(TokenType.FROM):
             return None
 
+        comments = self._prev_comments
+        this = self._parse_table()
+
         return self.expression(
-            exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table)
+            exp.From,
+            comments=comments,
+            this=self._parse_query_modifiers(this) if modifiers else this,
         )
 
     def _parse_match_recognize(self) -> t.Optional[exp.Expression]:
@@ -2136,6 +2164,9 @@ class Parser(metaclass=_Parser):
         )
 
     def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
+        if self._match(TokenType.COMMA):
+            return self.expression(exp.Join, this=self._parse_table())
+
         index = self._index
         natural, side, kind = self._parse_join_side_and_kind()
         hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
@@ -2176,55 +2207,66 @@ class Parser(metaclass=_Parser):
 
         return self.expression(exp.Join, **kwargs)  # type: ignore
 
-    def _parse_index(self) -> exp.Expression:
-        index = self._parse_id_var()
-        self._match(TokenType.ON)
-        self._match(TokenType.TABLE)  # hive
+    def _parse_index(
+        self,
+        index: t.Optional[exp.Expression] = None,
+    ) -> t.Optional[exp.Expression]:
+        if index:
+            unique = None
+            primary = None
+            amp = None
 
-        return self.expression(
-            exp.Index,
-            this=index,
-            table=self.expression(exp.Table, this=self._parse_id_var()),
-            columns=self._parse_expression(),
-        )
+            self._match(TokenType.ON)
+            self._match(TokenType.TABLE)  # hive
+            table = self._parse_table_parts(schema=True)
+        else:
+            unique = self._match(TokenType.UNIQUE)
+            primary = self._match_text_seq("PRIMARY")
+            amp = self._match_text_seq("AMP")
+            if not self._match(TokenType.INDEX):
+                return None
+            index = self._parse_id_var()
+            table = None
 
-    def _parse_create_table_index(self) -> t.Optional[exp.Expression]:
-        unique = self._match(TokenType.UNIQUE)
-        primary = self._match_text_seq("PRIMARY")
-        amp = self._match_text_seq("AMP")
-        if not self._match(TokenType.INDEX):
-            return None
-        index = self._parse_id_var()
-        columns = None
         if self._match(TokenType.L_PAREN, advance=False):
-            columns = self._parse_wrapped_csv(self._parse_column)
+            columns = self._parse_wrapped_csv(self._parse_ordered)
+        else:
+            columns = None
+
         return self.expression(
             exp.Index,
             this=index,
+            table=table,
             columns=columns,
             unique=unique,
             primary=primary,
             amp=amp,
+            partition_by=self._parse_partition_by(),
         )
 
-    def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
-        catalog = None
-        db = None
-
-        table = (
+    def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
+        return (
             (not schema and self._parse_function())
             or self._parse_id_var(any_token=False)
             or self._parse_string_as_identifier()
+            or self._parse_placeholder()
         )
 
+    def _parse_table_parts(self, schema: bool = False) -> exp.Table:
+        catalog = None
+        db = None
+        table = self._parse_table_part(schema=schema)
+
         while self._match(TokenType.DOT):
             if catalog:
                 # This allows nesting the table in arbitrarily many dot expressions if needed
-                table = self.expression(exp.Dot, this=table, expression=self._parse_id_var())
+                table = self.expression(
+                    exp.Dot, this=table, expression=self._parse_table_part(schema=schema)
+                )
             else:
                 catalog = db
                 db = table
-                table = self._parse_id_var()
+                table = self._parse_table_part(schema=schema)
 
         if not table:
             self.raise_error(f"Expected table name but got {self._curr}")
@@ -2237,28 +2279,24 @@ class Parser(metaclass=_Parser):
         self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
     ) -> t.Optional[exp.Expression]:
         lateral = self._parse_lateral()
-
         if lateral:
             return lateral
 
         unnest = self._parse_unnest()
-
         if unnest:
             return unnest
 
         values = self._parse_derived_table_values()
-
         if values:
             return values
 
         subquery = self._parse_select(table=True)
-
         if subquery:
             if not subquery.args.get("pivots"):
                 subquery.set("pivots", self._parse_pivots())
             return subquery
 
-        this = self._parse_table_parts(schema=schema)
+        this: exp.Expression = self._parse_table_parts(schema=schema)
 
         if schema:
             return self._parse_schema(this=this)
@@ -2267,7 +2305,6 @@ class Parser(metaclass=_Parser):
             table_sample = self._parse_table_sample()
 
         alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
-
         if alias:
             this.set("alias", alias)
 
@@ -2352,9 +2389,9 @@ class Parser(metaclass=_Parser):
 
         num = self._parse_number()
 
-        if self._match(TokenType.BUCKET):
+        if self._match_text_seq("BUCKET"):
             bucket_numerator = self._parse_number()
-            self._match(TokenType.OUT_OF)
+            self._match_text_seq("OUT", "OF")
             bucket_denominator = bucket_denominator = self._parse_number()
             self._match(TokenType.ON)
             bucket_field = self._parse_field()
@@ -2390,6 +2427,22 @@ class Parser(metaclass=_Parser):
     def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]:
         return list(iter(self._parse_pivot, None))
 
+    # https://duckdb.org/docs/sql/statements/pivot
+    def _parse_simplified_pivot(self) -> exp.Pivot:
+        def _parse_on() -> t.Optional[exp.Expression]:
+            this = self._parse_bitwise()
+            return self._parse_in(this) if self._match(TokenType.IN) else this
+
+        this = self._parse_table()
+        expressions = self._match(TokenType.ON) and self._parse_csv(_parse_on)
+        using = self._match(TokenType.USING) and self._parse_csv(
+            lambda: self._parse_alias(self._parse_function())
+        )
+        group = self._parse_group()
+        return self.expression(
+            exp.Pivot, this=this, expressions=expressions, using=using, group=group
+        )
+
     def _parse_pivot(self) -> t.Optional[exp.Expression]:
         index = self._index
 
@@ -2423,7 +2476,7 @@ class Parser(metaclass=_Parser):
         if not self._match(TokenType.IN):
             self.raise_error("Expecting IN")
 
-        field = self._parse_in(value)
+        field = self._parse_in(value, alias=True)
 
         self._match_r_paren()
 
@@ -2436,21 +2489,22 @@ class Parser(metaclass=_Parser):
             names = self._pivot_column_names(t.cast(t.List[exp.Expression], expressions))
 
             columns: t.List[exp.Expression] = []
-            for col in pivot.args["field"].expressions:
+            for fld in pivot.args["field"].expressions:
+                field_name = fld.sql() if self.IDENTIFY_PIVOT_STRINGS else fld.alias_or_name
                 for name in names:
                     if self.PREFIXED_PIVOT_COLUMNS:
-                        name = f"{name}_{col.alias_or_name}" if name else col.alias_or_name
+                        name = f"{name}_{field_name}" if name else field_name
                     else:
-                        name = f"{col.alias_or_name}_{name}" if name else col.alias_or_name
+                        name = f"{field_name}_{name}" if name else field_name
 
-                columns.append(exp.to_identifier(name, quoted=self.QUOTED_PIVOT_COLUMNS))
+                    columns.append(exp.to_identifier(name))
 
             pivot.set("columns", columns)
 
         return pivot
 
-    def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
-        return [agg.alias for agg in pivot_columns]
+    def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
+        return [agg.alias for agg in aggregations]
 
     def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
         if not skip_where_token and not self._match(TokenType.WHERE):
@@ -2477,6 +2531,7 @@ class Parser(metaclass=_Parser):
 
             rollup = None
             cube = None
+            totals = None
 
             with_ = self._match(TokenType.WITH)
             if self._match(TokenType.ROLLUP):
@@ -2487,7 +2542,11 @@ class Parser(metaclass=_Parser):
                 cube = with_ or self._parse_wrapped_csv(self._parse_column)
                 elements["cube"].extend(ensure_list(cube))
 
-            if not (expressions or grouping_sets or rollup or cube):
+            if self._match_text_seq("TOTALS"):
+                totals = True
+                elements["totals"] = True  # type: ignore
+
+            if not (grouping_sets or rollup or cube or totals):
                 break
 
         return self.expression(exp.Group, **elements)  # type: ignore
@@ -2527,9 +2586,9 @@ class Parser(metaclass=_Parser):
         )
 
     def _parse_sort(
-        self, token_type: TokenType, exp_class: t.Type[exp.Expression]
+        self, exp_class: t.Type[exp.Expression], *texts: str
    ) -> t.Optional[exp.Expression]:
-        if not self._match(token_type):
+        if not self._match_text_seq(*texts):
            return None
         return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
 
     def _parse_ordered(self) -> exp.Expression:
         this = self._parse_conjunction()
         self._match(TokenType.ASC)
         is_desc = self._match(TokenType.DESC)
-        is_nulls_first = self._match(TokenType.NULLS_FIRST)
-        is_nulls_last = self._match(TokenType.NULLS_LAST)
self._match_text_seq("NULLS", "FIRST") + is_nulls_last = self._match_text_seq("NULLS", "LAST") desc = is_desc or False asc = not desc nulls_first = is_nulls_first or False @@ -2578,7 +2637,7 @@ class Parser(metaclass=_Parser): self._match_set((TokenType.ROW, TokenType.ROWS)) - only = self._match(TokenType.ONLY) + only = self._match_text_seq("ONLY") with_ties = self._match_text_seq("WITH", "TIES") if only and with_ties: @@ -2602,13 +2661,37 @@ class Parser(metaclass=_Parser): self._match_set((TokenType.ROW, TokenType.ROWS)) return self.expression(exp.Offset, this=this, expression=count) - def _parse_lock(self) -> t.Optional[exp.Expression]: - if self._match_text_seq("FOR", "UPDATE"): - return self.expression(exp.Lock, update=True) - if self._match_text_seq("FOR", "SHARE"): - return self.expression(exp.Lock, update=False) + def _parse_locks(self) -> t.List[exp.Expression]: + # Lists are invariant, so we need to use a type hint here + locks: t.List[exp.Expression] = [] - return None + while True: + if self._match_text_seq("FOR", "UPDATE"): + update = True + elif self._match_text_seq("FOR", "SHARE") or self._match_text_seq( + "LOCK", "IN", "SHARE", "MODE" + ): + update = False + else: + break + + expressions = None + if self._match_text_seq("OF"): + expressions = self._parse_csv(lambda: self._parse_table(schema=True)) + + wait: t.Optional[bool | exp.Expression] = None + if self._match_text_seq("NOWAIT"): + wait = True + elif self._match_text_seq("WAIT"): + wait = self._parse_primary() + elif self._match_text_seq("SKIP", "LOCKED"): + wait = False + + locks.append( + self.expression(exp.Lock, update=update, expressions=expressions, wait=wait) + ) + + return locks def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match_set(self.SET_OPERATIONS): @@ -2672,7 +2755,7 @@ class Parser(metaclass=_Parser): def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: index = self._index - 1 negate = self._match(TokenType.NOT) - if self._match(TokenType.DISTINCT_FROM): + if self._match_text_seq("DISTINCT", "FROM"): klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ return self.expression(klass, this=this, expression=self._parse_expression()) @@ -2684,12 +2767,12 @@ class Parser(metaclass=_Parser): this = self.expression(exp.Is, this=this, expression=expression) return self.expression(exp.Not, this=this) if negate else this - def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression: + def _parse_in(self, this: t.Optional[exp.Expression], alias: bool = False) -> exp.In: unnest = self._parse_unnest() if unnest: this = self.expression(exp.In, this=this, unnest=unnest) elif self._match(TokenType.L_PAREN): - expressions = self._parse_csv(self._parse_select_or_expression) + expressions = self._parse_csv(lambda: self._parse_select_or_expression(alias=alias)) if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): this = self.expression(exp.In, this=this, query=expressions[0]) @@ -2722,15 +2805,19 @@ class Parser(metaclass=_Parser): # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse # each INTERVAL expression into this canonical form so it's easy to transpile - if this and isinstance(this, exp.Literal): - if this.is_number: - this = exp.Literal.string(this.name) - - # Try to not clutter Snowflake's multi-part intervals like INTERVAL '1 day, 1 year' + if this and this.is_number: + this = exp.Literal.string(this.name) + elif this and this.is_string: parts = 
+            parts = this.name.split()
 
-            if not unit and len(parts) <= 2:
-                this = exp.Literal.string(seq_get(parts, 0))
-                unit = self.expression(exp.Var, this=seq_get(parts, 1))
+            if len(parts) == 2:
+                if unit:
+                    # this is not actually a unit, it's something else
+                    unit = None
+                    self._retreat(self._index - 1)
+                else:
+                    this = exp.Literal.string(parts[0])
+                    unit = self.expression(exp.Var, this=parts[1])
 
         return self.expression(exp.Interval, this=this, unit=unit)
 
@@ -2783,13 +2870,22 @@ class Parser(metaclass=_Parser):
             if parser:
                 return parser(self, this, data_type)
             return self.expression(exp.Cast, this=this, to=data_type)
-        if not data_type.args.get("expressions"):
+        if not data_type.expressions:
             self._retreat(index)
             return self._parse_column()
-        return data_type
+        return self._parse_column_ops(data_type)
 
         return this
 
+    def _parse_type_size(self) -> t.Optional[exp.Expression]:
+        this = self._parse_type()
+        if not this:
+            return None
+
+        return self.expression(
+            exp.DataTypeSize, this=this, expression=self._parse_var(any_token=True)
+        )
+
     def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
         index = self._index
 
@@ -2814,7 +2910,7 @@ class Parser(metaclass=_Parser):
         elif nested:
             expressions = self._parse_csv(self._parse_types)
         else:
-            expressions = self._parse_csv(self._parse_conjunction)
+            expressions = self._parse_csv(self._parse_type_size)
 
         if not expressions or not self._match(TokenType.R_PAREN):
             self._retreat(index)
@@ -2858,13 +2954,14 @@ class Parser(metaclass=_Parser):
         value: t.Optional[exp.Expression] = None
 
         if type_token in self.TIMESTAMPS:
-            if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
+            if self._match_text_seq("WITH", "TIME", "ZONE") or type_token == TokenType.TIMESTAMPTZ:
                 value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
             elif (
-                self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
+                self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE")
+                or type_token == TokenType.TIMESTAMPLTZ
             ):
                 value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
-            elif self._match(TokenType.WITHOUT_TIME_ZONE):
+            elif self._match_text_seq("WITHOUT", "TIME", "ZONE"):
                 if type_token == TokenType.TIME:
                     value = exp.DataType(this=exp.DataType.Type.TIME, expressions=expressions)
                 else:
@@ -2909,7 +3006,7 @@ class Parser(metaclass=_Parser):
         return self._parse_column_def(this)
 
     def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
-        if not self._match(TokenType.AT_TIME_ZONE):
+        if not self._match_text_seq("AT", "TIME", "ZONE"):
             return this
         return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())
 
@@ -2919,6 +3016,9 @@ class Parser(metaclass=_Parser):
             this = self.expression(exp.Column, this=this)
         elif not this:
             return self._parse_bracket(this)
+        return self._parse_column_ops(this)
+
+    def _parse_column_ops(self, this: exp.Expression) -> exp.Expression:
         this = self._parse_bracket(this)
 
         while self._match_set(self.COLUMN_OPERATORS):
@@ -2929,7 +3029,7 @@ class Parser(metaclass=_Parser):
                 field = self._parse_types()
                 if not field:
                     self.raise_error("Expected type")
-            elif op:
+            elif op and self._curr:
                 self._advance()
                 value = self._prev.text
                 field = (
@@ -2963,7 +3063,6 @@ class Parser(metaclass=_Parser):
             else:
                 this = self.expression(exp.Dot, this=this, expression=field)
             this = self._parse_bracket(this)
-
         return this
 
     def _parse_primary(self) -> t.Optional[exp.Expression]:
@@ -2989,12 +3088,9 @@ class Parser(metaclass=_Parser):
             if query:
                 expressions = [query]
             else:
-                expressions = self._parse_csv(
-                    lambda: self._parse_alias(self._parse_conjunction(), explicit=True)
-                )
+                expressions = self._parse_csv(self._parse_expression)
 
-            this = seq_get(expressions, 0)
-            self._parse_query_modifiers(this)
+            this = self._parse_query_modifiers(seq_get(expressions, 0))
 
             if isinstance(this, exp.Subqueryable):
                 this = self._parse_set_operations(
@@ -3065,20 +3161,12 @@ class Parser(metaclass=_Parser):
             functions = self.FUNCTIONS
 
         function = functions.get(upper)
-        args = self._parse_csv(self._parse_lambda)
-
-        if function and not anonymous:
-            # Clickhouse supports function calls like foo(x, y)(z), so for these we need to also parse the
-            # second parameter list (i.e. "(z)") and the corresponding function will receive both arg lists.
-            if count_params(function) == 2:
-                params = None
-                if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
-                    params = self._parse_csv(self._parse_lambda)
-
-                this = function(args, params)
-            else:
-                this = function(args)
+        alias = upper in self.FUNCTIONS_WITH_ALIASED_ARGS
+        args = self._parse_csv(lambda: self._parse_lambda(alias=alias))
 
+        if function and not anonymous:
+            this = function(args)
             self.validate_expression(this, args)
         else:
             this = self.expression(exp.Anonymous, this=this, expressions=args)
@@ -3113,9 +3201,6 @@ class Parser(metaclass=_Parser):
 
         return self.expression(exp.Identifier, this=token.text)
 
-    def _parse_national(self, token: Token) -> exp.Expression:
-        return self.expression(exp.National, this=exp.Literal.string(token.text))
-
     def _parse_session_parameter(self) -> exp.Expression:
         kind = None
         this = self._parse_id_var() or self._parse_primary()
@@ -3126,7 +3211,7 @@ class Parser(metaclass=_Parser):
 
         return self.expression(exp.SessionParameter, this=this, kind=kind)
 
-    def _parse_lambda(self) -> t.Optional[exp.Expression]:
+    def _parse_lambda(self, alias: bool = False) -> t.Optional[exp.Expression]:
         index = self._index
 
         if self._match(TokenType.L_PAREN):
@@ -3149,7 +3234,7 @@ class Parser(metaclass=_Parser):
                 exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
             )
         else:
-            this = self._parse_select_or_expression()
+            this = self._parse_select_or_expression(alias=alias)
 
         if isinstance(this, exp.EQ):
             left = this.this
@@ -3161,13 +3246,15 @@ class Parser(metaclass=_Parser):
     def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
         index = self._index
 
-        try:
-            if self._parse_select(nested=True):
-                return this
-        except Exception:
-            pass
-        finally:
-            self._retreat(index)
+        if not self.errors:
+            try:
+                if self._parse_select(nested=True):
+                    return this
+            except ParseError:
+                pass
+            finally:
+                self.errors.clear()
+                self._retreat(index)
 
         if not self._match(TokenType.L_PAREN):
             return this
@@ -3227,13 +3314,18 @@ class Parser(metaclass=_Parser):
         return self.expression(exp.CompressColumnConstraint, this=self._parse_bitwise())
 
     def _parse_generated_as_identity(self) -> exp.Expression:
-        if self._match(TokenType.BY_DEFAULT):
-            this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False)
+        if self._match_text_seq("BY", "DEFAULT"):
+            on_null = self._match_pair(TokenType.ON, TokenType.NULL)
+            this = self.expression(
+                exp.GeneratedAsIdentityColumnConstraint, this=False, on_null=on_null
+            )
         else:
             self._match_text_seq("ALWAYS")
             this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
 
-        self._match_text_seq("AS", "IDENTITY")
+        self._match(TokenType.ALIAS)
+        identity = self._match_text_seq("IDENTITY")
+
         if self._match(TokenType.L_PAREN):
self._match_text_seq("START", "WITH"): this.set("start", self._parse_bitwise()) @@ -3249,6 +3341,9 @@ class Parser(metaclass=_Parser): elif self._match_text_seq("NO", "CYCLE"): this.set("cycle", False) + if not identity: + this.set("expression", self._parse_bitwise()) + self._match_r_paren() return this @@ -3307,9 +3402,10 @@ class Parser(metaclass=_Parser): return self.CONSTRAINT_PARSERS[constraint](self) def _parse_unique(self) -> exp.Expression: - if not self._match(TokenType.L_PAREN, advance=False): - return self.expression(exp.UniqueColumnConstraint) - return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars()) + self._match_text_seq("KEY") + return self.expression( + exp.UniqueColumnConstraint, this=self._parse_schema(self._parse_id_var(any_token=False)) + ) def _parse_key_constraint_options(self) -> t.List[str]: options = [] @@ -3321,9 +3417,9 @@ class Parser(metaclass=_Parser): action = None on = self._advance_any() and self._prev.text - if self._match(TokenType.NO_ACTION): + if self._match_text_seq("NO", "ACTION"): action = "NO ACTION" - elif self._match(TokenType.CASCADE): + elif self._match_text_seq("CASCADE"): action = "CASCADE" elif self._match_pair(TokenType.SET, TokenType.NULL): action = "SET NULL" @@ -3348,7 +3444,7 @@ class Parser(metaclass=_Parser): return options - def _parse_references(self, match=True) -> t.Optional[exp.Expression]: + def _parse_references(self, match: bool = True) -> t.Optional[exp.Expression]: if match and not self._match(TokenType.REFERENCES): return None @@ -3372,7 +3468,7 @@ class Parser(metaclass=_Parser): kind = self._prev.text.lower() - if self._match(TokenType.NO_ACTION): + if self._match_text_seq("NO", "ACTION"): action = "NO ACTION" elif self._match(TokenType.SET): self._match_set((TokenType.NULL, TokenType.DEFAULT)) @@ -3396,11 +3492,19 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.L_PAREN, advance=False): return self.expression(exp.PrimaryKeyColumnConstraint, desc=desc) - expressions = self._parse_wrapped_id_vars() + expressions = self._parse_wrapped_csv(self._parse_field) options = self._parse_key_constraint_options() return self.expression(exp.PrimaryKey, expressions=expressions, options=options) + @t.overload + def _parse_bracket(self, this: exp.Expression) -> exp.Expression: + ... + + @t.overload def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + ... + + def _parse_bracket(self, this): if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)): return this @@ -3493,7 +3597,12 @@ class Parser(metaclass=_Parser): this = self._parse_conjunction() if not self._match(TokenType.ALIAS): - self.raise_error("Expected AS after CAST") + if self._match(TokenType.COMMA): + return self.expression( + exp.CastToStrType, this=this, expression=self._parse_string() + ) + else: + self.raise_error("Expected AS after CAST") to = self._parse_types() @@ -3524,7 +3633,7 @@ class Parser(metaclass=_Parser): # Checks if we can parse an order clause: WITHIN GROUP (ORDER BY [ASC | DESC]). # This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that # the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them. 
-        if not self._match(TokenType.WITHIN_GROUP):
+        if not self._match_text_seq("WITHIN", "GROUP"):
             self._retreat(index)
             this = exp.GroupConcat.from_arg_list(args)
             self.validate_expression(this, args)
@@ -3674,6 +3783,27 @@ class Parser(metaclass=_Parser):
             exp.MatchAgainst, this=this, expressions=expressions, modifier=modifier
         )
 
+    # https://learn.microsoft.com/en-us/sql/t-sql/functions/openjson-transact-sql?view=sql-server-ver16
+    def _parse_open_json(self) -> exp.Expression:
+        this = self._parse_bitwise()
+        path = self._match(TokenType.COMMA) and self._parse_string()
+
+        def _parse_open_json_column_def() -> exp.Expression:
+            this = self._parse_field(any_token=True)
+            kind = self._parse_types()
+            path = self._parse_string()
+            as_json = self._match_pair(TokenType.ALIAS, TokenType.JSON)
+
+            return self.expression(
+                exp.OpenJSONColumnDef, this=this, kind=kind, path=path, as_json=as_json
+            )
+
+        expressions = None
+        if self._match_pair(TokenType.R_PAREN, TokenType.WITH):
+            self._match_l_paren()
+            expressions = self._parse_csv(_parse_open_json_column_def)
+
+        return self.expression(exp.OpenJSON, this=this, path=path, expressions=expressions)
+
     def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
         args = self._parse_csv(self._parse_bitwise)
 
@@ -3722,7 +3852,7 @@ class Parser(metaclass=_Parser):
         position = None
         collation = None
 
-        if self._match_set(self.TRIM_TYPES):
+        if self._match_texts(self.TRIM_TYPES):
             position = self._prev.text.upper()
 
         expression = self._parse_bitwise()
@@ -3752,9 +3882,9 @@ class Parser(metaclass=_Parser):
     def _parse_respect_or_ignore_nulls(
         self, this: t.Optional[exp.Expression]
     ) -> t.Optional[exp.Expression]:
-        if self._match(TokenType.IGNORE_NULLS):
+        if self._match_text_seq("IGNORE", "NULLS"):
             return self.expression(exp.IgnoreNulls, this=this)
-        if self._match(TokenType.RESPECT_NULLS):
+        if self._match_text_seq("RESPECT", "NULLS"):
             return self.expression(exp.RespectNulls, this=this)
 
         return this
@@ -3767,7 +3897,7 @@ class Parser(metaclass=_Parser):
 
         # T-SQL allows the OVER (...) syntax after WITHIN GROUP.
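Note for reviewers: a sketch of the T-SQL construct the new _parse_open_json targets, assuming OPENJSON is wired up as a function parser and read with the tsql dialect:

    import sqlglot
    from sqlglot import exp

    sql = """SELECT * FROM OPENJSON(@js, '$.rows')
    WITH (id INT '$.id', payload NVARCHAR(MAX) '$.payload' AS JSON)"""
    node = sqlglot.parse_one(sql, read="tsql").find(exp.OpenJSON)
    # one OpenJSONColumnDef per entry in the WITH clause
    assert node is not None and len(node.expressions) == 2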
         # https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16
-        if self._match(TokenType.WITHIN_GROUP):
+        if self._match_text_seq("WITHIN", "GROUP"):
             order = self._parse_wrapped(self._parse_order)
             this = self.expression(exp.WithinGroup, this=this, expression=order)
 
@@ -3846,10 +3976,11 @@ class Parser(metaclass=_Parser):
 
         return {
             "value": (
-                self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text
-            )
-            or self._parse_bitwise(),
-            "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text,
+                (self._match_text_seq("UNBOUNDED") and "UNBOUNDED")
+                or (self._match_text_seq("CURRENT", "ROW") and "CURRENT ROW")
+                or self._parse_bitwise()
+            ),
+            "side": self._match_texts(self.WINDOW_SIDES) and self._prev.text,
         }
 
     def _parse_alias(
@@ -3956,7 +4087,7 @@ class Parser(metaclass=_Parser):
     def _parse_parameter(self) -> exp.Expression:
         wrapped = self._match(TokenType.L_BRACE)
-        this = self._parse_var() or self._parse_primary()
+        this = self._parse_var() or self._parse_identifier() or self._parse_primary()
         self._match(TokenType.R_BRACE)
 
         return self.expression(exp.Parameter, this=this, wrapped=wrapped)
@@ -4011,26 +4142,33 @@ class Parser(metaclass=_Parser):
 
         return this
 
-    def _parse_wrapped_id_vars(self) -> t.List[t.Optional[exp.Expression]]:
-        return self._parse_wrapped_csv(self._parse_id_var)
+    def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[t.Optional[exp.Expression]]:
+        return self._parse_wrapped_csv(self._parse_id_var, optional=optional)
 
     def _parse_wrapped_csv(
-        self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
+        self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA, optional: bool = False
     ) -> t.List[t.Optional[exp.Expression]]:
-        return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep))
+        return self._parse_wrapped(
+            lambda: self._parse_csv(parse_method, sep=sep), optional=optional
+        )
 
-    def _parse_wrapped(self, parse_method: t.Callable) -> t.Any:
-        self._match_l_paren()
+    def _parse_wrapped(self, parse_method: t.Callable, optional: bool = False) -> t.Any:
+        wrapped = self._match(TokenType.L_PAREN)
+        if not wrapped and not optional:
+            self.raise_error("Expecting (")
         parse_result = parse_method()
-        self._match_r_paren()
+        if wrapped:
+            self._match_r_paren()
 
         return parse_result
 
-    def _parse_select_or_expression(self) -> t.Optional[exp.Expression]:
-        return self._parse_select() or self._parse_set_operations(self._parse_expression())
+    def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]:
+        return self._parse_select() or self._parse_set_operations(
+            self._parse_expression() if alias else self._parse_conjunction()
+        )
 
     def _parse_ddl_select(self) -> t.Optional[exp.Expression]:
-        return self._parse_set_operations(
-            self._parse_select(nested=True, parse_subquery_alias=False)
+        return self._parse_query_modifiers(
+            self._parse_set_operations(self._parse_select(nested=True, parse_subquery_alias=False))
         )
 
     def _parse_transaction(self) -> exp.Expression:
@@ -4391,11 +4529,11 @@ class Parser(metaclass=_Parser):
 
         return None
 
-    def _match_l_paren(self, expression=None):
+    def _match_l_paren(self, expression: t.Optional[exp.Expression] = None) -> None:
         if not self._match(TokenType.L_PAREN, expression=expression):
             self.raise_error("Expecting (")
 
-    def _match_r_paren(self, expression=None):
+    def _match_r_paren(self, expression: t.Optional[exp.Expression] = None) -> None:
         if not self._match(TokenType.R_PAREN, expression=expression):
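Note for reviewers: the reworked frame-boundary dict above feeds exp.WindowSpec through its "value"/"side" keys; a sketch, assuming the WindowSpec arg names start/start_side/end used elsewhere in the codebase:

    import sqlglot
    from sqlglot import exp

    sql = "SELECT SUM(x) OVER (ORDER BY y ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM t"
    spec = sqlglot.parse_one(sql).find(exp.WindowSpec)
    # "UNBOUNDED" and "CURRENT ROW" come back as plain strings, per the dict above
    assert spec is not None and spec.args["start"] == "UNBOUNDED"
    assert spec.args["start_side"] == "PRECEDING" and spec.args["end"] == "CURRENT ROW"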
             self.raise_error("Expecting )")
 
@@ -4420,6 +4558,16 @@ class Parser(metaclass=_Parser):
 
         return True
 
+    @t.overload
+    def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression:
+        ...
+
+    @t.overload
+    def _replace_columns_with_dots(
+        self, this: t.Optional[exp.Expression]
+    ) -> t.Optional[exp.Expression]:
+        ...
+
     def _replace_columns_with_dots(self, this):
         if isinstance(this, exp.Dot):
             exp.replace_children(this, self._replace_columns_with_dots)
@@ -4433,9 +4581,15 @@ class Parser(metaclass=_Parser):
             )
         elif isinstance(this, exp.Identifier):
             this = self.expression(exp.Var, this=this.name)
+
         return this
 
-    def _replace_lambda(self, node, lambda_variables):
+    def _replace_lambda(
+        self, node: t.Optional[exp.Expression], lambda_variables: t.Set[str]
+    ) -> t.Optional[exp.Expression]:
+        if not node:
+            return node
+
         for column in node.find_all(exp.Column):
             if column.parts[0].name in lambda_variables:
                 dot_or_id = column.to_dot() if column.table else column.this
-- 
cgit v1.2.3
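Note for reviewers: _replace_lambda's effect is observable even in the default dialect, since any function argument may be parsed as a lambda; a sketch (transform_list is a hypothetical function name):

    import sqlglot
    from sqlglot import exp

    # `x` is registered as a lambda variable, so occurrences of it in the body
    # are rewritten by _replace_lambda rather than resolved as table columns.
    tree = sqlglot.parse_one("SELECT transform_list(arr, x -> x + 1) FROM t")
    lamb = tree.find(exp.Lambda)
    assert lamb is not None and lamb.expressions[0].name == "x"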