From 8fe30fd23dc37ec3516e530a86d1c4b604e71241 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 10 Dec 2023 11:46:01 +0100 Subject: Merging upstream version 20.1.0. Signed-off-by: Daniel Baumann --- sqlglot/parser.py | 297 +++++++++++++++++++++++++++++++++--------------------- 1 file changed, 183 insertions(+), 114 deletions(-) (limited to 'sqlglot/parser.py') diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 1dab600..c7e27a3 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -13,6 +13,7 @@ from sqlglot.trie import TrieResult, in_trie, new_trie if t.TYPE_CHECKING: from sqlglot._typing import E + from sqlglot.dialects.dialect import Dialect, DialectType logger = logging.getLogger("sqlglot") @@ -46,6 +47,19 @@ def binary_range_parser( ) +def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func: + # Default argument order is base, expression + this = seq_get(args, 0) + expression = seq_get(args, 1) + + if expression: + if not dialect.LOG_BASE_FIRST: + this, expression = expression, this + return exp.Log(this=this, expression=expression) + + return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this) + + class _Parser(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) @@ -72,13 +86,24 @@ class Parser(metaclass=_Parser): """ FUNCTIONS: t.Dict[str, t.Callable] = { - **{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()}, + **{name: func.from_arg_list for name, func in exp.FUNCTION_BY_NAME.items()}, + "CONCAT": lambda args, dialect: exp.Concat( + expressions=args, + safe=not dialect.STRICT_STRING_CONCAT, + coalesce=dialect.CONCAT_COALESCE, + ), + "CONCAT_WS": lambda args, dialect: exp.ConcatWs( + expressions=args, + safe=not dialect.STRICT_STRING_CONCAT, + coalesce=dialect.CONCAT_COALESCE, + ), "DATE_TO_DATE_STR": lambda args: exp.Cast( this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), ), "GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)), "LIKE": parse_like, + "LOG": parse_logarithm, "TIME_TO_TIME_STR": lambda args: exp.Cast( this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), @@ -229,7 +254,7 @@ class Parser(metaclass=_Parser): TokenType.SOME: exp.Any, } - RESERVED_KEYWORDS = { + RESERVED_TOKENS = { *Tokenizer.SINGLE_TOKENS.values(), TokenType.SELECT, } @@ -245,9 +270,11 @@ class Parser(metaclass=_Parser): CREATABLES = { TokenType.COLUMN, + TokenType.CONSTRAINT, TokenType.FUNCTION, TokenType.INDEX, TokenType.PROCEDURE, + TokenType.FOREIGN_KEY, *DB_CREATABLES, } @@ -291,6 +318,7 @@ class Parser(metaclass=_Parser): TokenType.NATURAL, TokenType.NEXT, TokenType.OFFSET, + TokenType.OPERATOR, TokenType.ORDINALITY, TokenType.OVERLAPS, TokenType.OVERWRITE, @@ -299,7 +327,10 @@ class Parser(metaclass=_Parser): TokenType.PIVOT, TokenType.PRAGMA, TokenType.RANGE, + TokenType.RECURSIVE, TokenType.REFERENCES, + TokenType.REFRESH, + TokenType.REPLACE, TokenType.RIGHT, TokenType.ROW, TokenType.ROWS, @@ -390,6 +421,7 @@ class Parser(metaclass=_Parser): } EQUALITY = { + TokenType.COLON_EQ: exp.PropertyEQ, TokenType.EQ: exp.EQ, TokenType.NEQ: exp.NEQ, TokenType.NULLSAFE_EQ: exp.NullSafeEQ, @@ -406,7 +438,6 @@ class Parser(metaclass=_Parser): TokenType.AMP: exp.BitwiseAnd, TokenType.CARET: exp.BitwiseXor, TokenType.PIPE: exp.BitwiseOr, - TokenType.DPIPE: exp.DPipe, } TERM = { @@ -423,6 +454,8 @@ class Parser(metaclass=_Parser): TokenType.STAR: exp.Mul, } + EXPONENT: t.Dict[TokenType, t.Type[exp.Expression]] = {} + TIMES = { TokenType.TIME, TokenType.TIMETZ, @@ -558,6 +591,7 @@ class Parser(metaclass=_Parser): TokenType.MERGE: lambda self: self._parse_merge(), TokenType.PIVOT: lambda self: self._parse_simplified_pivot(), TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()), + TokenType.REFRESH: lambda self: self._parse_refresh(), TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), TokenType.SET: lambda self: self._parse_set(), TokenType.UNCACHE: lambda self: self._parse_uncache(), @@ -697,6 +731,7 @@ class Parser(metaclass=_Parser): exp.StabilityProperty, this=exp.Literal.string("STABLE") ), "STORED": lambda self: self._parse_stored(), + "SYSTEM_VERSIONING": lambda self: self._parse_system_versioning_property(), "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property), "TEMP": lambda self: self.expression(exp.TemporaryProperty), "TEMPORARY": lambda self: self.expression(exp.TemporaryProperty), @@ -754,6 +789,7 @@ class Parser(metaclass=_Parser): ) or self.expression(exp.OnProperty, this=self._parse_id_var()), "PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()), + "PERIOD": lambda self: self._parse_period_for_system_time(), "PRIMARY KEY": lambda self: self._parse_primary_key(), "REFERENCES": lambda self: self._parse_references(match=False), "TITLE": lambda self: self.expression( @@ -775,7 +811,7 @@ class Parser(metaclass=_Parser): "RENAME": lambda self: self._parse_alter_table_rename(), } - SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"} + SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE", "PERIOD"} NO_PAREN_FUNCTION_PARSERS = { "ANY": lambda self: self.expression(exp.Any, this=self._parse_bitwise()), @@ -794,14 +830,11 @@ class Parser(metaclass=_Parser): FUNCTION_PARSERS = { "ANY_VALUE": lambda self: self._parse_any_value(), "CAST": lambda self: self._parse_cast(self.STRICT_CAST), - "CONCAT": lambda self: self._parse_concat(), - "CONCAT_WS": lambda self: self._parse_concat_ws(), "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), "DECODE": lambda self: self._parse_decode(), "EXTRACT": lambda self: self._parse_extract(), "JSON_OBJECT": lambda self: self._parse_json_object(), "JSON_TABLE": lambda self: self._parse_json_table(), - "LOG": lambda self: self._parse_logarithm(), "MATCH": lambda self: self._parse_match_against(), "OPENJSON": lambda self: self._parse_open_json(), "POSITION": lambda self: self._parse_position(), @@ -877,6 +910,7 @@ class Parser(metaclass=_Parser): CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"} OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS"} + OPTYPE_FOLLOW_TOKENS = {TokenType.COMMA, TokenType.R_PAREN} TABLE_INDEX_HINT_TOKENS = {TokenType.FORCE, TokenType.IGNORE, TokenType.USE} @@ -896,17 +930,13 @@ class Parser(metaclass=_Parser): STRICT_CAST = True - # A NULL arg in CONCAT yields NULL by default - CONCAT_NULL_OUTPUTS_STRING = False - PREFIXED_PIVOT_COLUMNS = False IDENTIFY_PIVOT_STRINGS = False - LOG_BASE_FIRST = True LOG_DEFAULTS_TO_LN = False # Whether or not ADD is present for each column added by ALTER TABLE - ALTER_TABLE_ADD_COLUMN_KEYWORD = True + ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = True # Whether or not the table sample clause expects CSV syntax TABLESAMPLE_CSV = False @@ -921,6 +951,7 @@ class Parser(metaclass=_Parser): "error_level", "error_message_context", "max_errors", + "dialect", "sql", "errors", "_tokens", @@ -929,35 +960,25 @@ class Parser(metaclass=_Parser): "_next", "_prev", "_prev_comments", - "_tokenizer", ) # Autofilled - TOKENIZER_CLASS: t.Type[Tokenizer] = Tokenizer - INDEX_OFFSET: int = 0 - UNNEST_COLUMN_ONLY: bool = False - ALIAS_POST_TABLESAMPLE: bool = False - STRICT_STRING_CONCAT = False - SUPPORTS_USER_DEFINED_TYPES = True - NORMALIZE_FUNCTIONS = "upper" - NULL_ORDERING: str = "nulls_are_small" SHOW_TRIE: t.Dict = {} SET_TRIE: t.Dict = {} - FORMAT_MAPPING: t.Dict[str, str] = {} - FORMAT_TRIE: t.Dict = {} - TIME_MAPPING: t.Dict[str, str] = {} - TIME_TRIE: t.Dict = {} def __init__( self, error_level: t.Optional[ErrorLevel] = None, error_message_context: int = 100, max_errors: int = 3, + dialect: DialectType = None, ): + from sqlglot.dialects import Dialect + self.error_level = error_level or ErrorLevel.IMMEDIATE self.error_message_context = error_message_context self.max_errors = max_errors - self._tokenizer = self.TOKENIZER_CLASS() + self.dialect = Dialect.get_or_raise(dialect) self.reset() def reset(self): @@ -1384,7 +1405,7 @@ class Parser(metaclass=_Parser): if self._match_texts(self.CLONE_KEYWORDS): copy = self._prev.text.lower() == "copy" clone = self._parse_table(schema=True) - when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper() + when = self._match_texts(("AT", "BEFORE")) and self._prev.text.upper() clone_kind = ( self._match(TokenType.L_PAREN) and self._match_texts(self.CLONE_KINDS) @@ -1524,6 +1545,22 @@ class Parser(metaclass=_Parser): return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE")) + def _parse_system_versioning_property(self) -> exp.WithSystemVersioningProperty: + self._match_pair(TokenType.EQ, TokenType.ON) + + prop = self.expression(exp.WithSystemVersioningProperty) + if self._match(TokenType.L_PAREN): + self._match_text_seq("HISTORY_TABLE", "=") + prop.set("this", self._parse_table_parts()) + + if self._match(TokenType.COMMA): + self._match_text_seq("DATA_CONSISTENCY_CHECK", "=") + prop.set("expression", self._advance_any() and self._prev.text.upper()) + + self._match_r_paren() + + return prop + def _parse_with_property( self, ) -> t.Optional[exp.Expression] | t.List[exp.Expression]: @@ -2140,7 +2177,11 @@ class Parser(metaclass=_Parser): return self._parse_expressions() def _parse_select( - self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True + self, + nested: bool = False, + table: bool = False, + parse_subquery_alias: bool = True, + parse_set_operation: bool = True, ) -> t.Optional[exp.Expression]: cte = self._parse_with() @@ -2216,7 +2257,11 @@ class Parser(metaclass=_Parser): t.cast(exp.From, self._parse_from(skip_from_token=True)) ) else: - this = self._parse_table() if table else self._parse_select(nested=True) + this = ( + self._parse_table() + if table + else self._parse_select(nested=True, parse_set_operation=False) + ) this = self._parse_set_operations(self._parse_query_modifiers(this)) self._match_r_paren() @@ -2235,7 +2280,9 @@ class Parser(metaclass=_Parser): else: this = None - return self._parse_set_operations(this) + if parse_set_operation: + return self._parse_set_operations(this) + return this def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.With]: if not skip_with_token and not self._match(TokenType.WITH): @@ -2563,9 +2610,8 @@ class Parser(metaclass=_Parser): if self._match_texts(self.OPCLASS_FOLLOW_KEYWORDS, advance=False): return this - opclass = self._parse_var(any_token=True) - if opclass: - return self.expression(exp.Opclass, this=this, expression=opclass) + if not self._match_set(self.OPTYPE_FOLLOW_TOKENS, advance=False): + return self.expression(exp.Opclass, this=this, expression=self._parse_table_parts()) return this @@ -2630,7 +2676,7 @@ class Parser(metaclass=_Parser): while self._match_set(self.TABLE_INDEX_HINT_TOKENS): hint = exp.IndexTableHint(this=self._prev.text.upper()) - self._match_texts({"INDEX", "KEY"}) + self._match_texts(("INDEX", "KEY")) if self._match(TokenType.FOR): hint.set("target", self._advance_any() and self._prev.text.upper()) @@ -2650,7 +2696,7 @@ class Parser(metaclass=_Parser): def _parse_table_parts(self, schema: bool = False) -> exp.Table: catalog = None db = None - table = self._parse_table_part(schema=schema) + table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema) while self._match(TokenType.DOT): if catalog: @@ -2661,7 +2707,7 @@ class Parser(metaclass=_Parser): else: catalog = db db = table - table = self._parse_table_part(schema=schema) + table = self._parse_table_part(schema=schema) or "" if not table: self.raise_error(f"Expected table name but got {self._curr}") @@ -2709,7 +2755,7 @@ class Parser(metaclass=_Parser): if version: this.set("version", version) - if self.ALIAS_POST_TABLESAMPLE: + if self.dialect.ALIAS_POST_TABLESAMPLE: table_sample = self._parse_table_sample() alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS) @@ -2724,7 +2770,7 @@ class Parser(metaclass=_Parser): if not this.args.get("pivots"): this.set("pivots", self._parse_pivots()) - if not self.ALIAS_POST_TABLESAMPLE: + if not self.dialect.ALIAS_POST_TABLESAMPLE: table_sample = self._parse_table_sample() if table_sample: @@ -2776,13 +2822,13 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.UNNEST): return None - expressions = self._parse_wrapped_csv(self._parse_type) + expressions = self._parse_wrapped_csv(self._parse_equality) offset = self._match_pair(TokenType.WITH, TokenType.ORDINALITY) alias = self._parse_table_alias() if with_alias else None if alias: - if self.UNNEST_COLUMN_ONLY: + if self.dialect.UNNEST_COLUMN_ONLY: if alias.args.get("columns"): self.raise_error("Unexpected extra column alias in unnest.") @@ -2845,7 +2891,7 @@ class Parser(metaclass=_Parser): num = ( self._parse_factor() if self._match(TokenType.NUMBER, advance=False) - else self._parse_primary() + else self._parse_primary() or self._parse_placeholder() ) if self._match_text_seq("BUCKET"): @@ -3108,10 +3154,10 @@ class Parser(metaclass=_Parser): if ( not explicitly_null_ordered and ( - (not desc and self.NULL_ORDERING == "nulls_are_small") - or (desc and self.NULL_ORDERING != "nulls_are_small") + (not desc and self.dialect.NULL_ORDERING == "nulls_are_small") + or (desc and self.dialect.NULL_ORDERING != "nulls_are_small") ) - and self.NULL_ORDERING != "nulls_are_last" + and self.dialect.NULL_ORDERING != "nulls_are_last" ): nulls_first = True @@ -3124,7 +3170,7 @@ class Parser(metaclass=_Parser): comments = self._prev_comments if top: limit_paren = self._match(TokenType.L_PAREN) - expression = self._parse_number() + expression = self._parse_term() if limit_paren else self._parse_number() if limit_paren: self._match_r_paren() @@ -3225,7 +3271,9 @@ class Parser(metaclass=_Parser): this=this, distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL), by_name=self._match_text_seq("BY", "NAME"), - expression=self._parse_set_operations(self._parse_select(nested=True)), + expression=self._parse_set_operations( + self._parse_select(nested=True, parse_set_operation=False) + ), ) def _parse_expression(self) -> t.Optional[exp.Expression]: @@ -3287,7 +3335,8 @@ class Parser(metaclass=_Parser): unnest = self._parse_unnest(with_alias=False) if unnest: this = self.expression(exp.In, this=this, unnest=unnest) - elif self._match(TokenType.L_PAREN): + elif self._match_set((TokenType.L_PAREN, TokenType.L_BRACKET)): + matched_l_paren = self._prev.token_type == TokenType.L_PAREN expressions = self._parse_csv(lambda: self._parse_select_or_expression(alias=alias)) if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): @@ -3295,13 +3344,16 @@ class Parser(metaclass=_Parser): else: this = self.expression(exp.In, this=this, expressions=expressions) - self._match_r_paren(this) + if matched_l_paren: + self._match_r_paren(this) + elif not self._match(TokenType.R_BRACKET, expression=this): + self.raise_error("Expecting ]") else: this = self.expression(exp.In, this=this, field=self._parse_field()) return this - def _parse_between(self, this: exp.Expression) -> exp.Between: + def _parse_between(self, this: t.Optional[exp.Expression]) -> exp.Between: low = self._parse_bitwise() self._match(TokenType.AND) high = self._parse_bitwise() @@ -3357,6 +3409,13 @@ class Parser(metaclass=_Parser): this=this, expression=self._parse_term(), ) + elif self.dialect.DPIPE_IS_STRING_CONCAT and self._match(TokenType.DPIPE): + this = self.expression( + exp.DPipe, + this=this, + expression=self._parse_term(), + safe=not self.dialect.STRICT_STRING_CONCAT, + ) elif self._match(TokenType.DQMARK): this = self.expression(exp.Coalesce, this=this, expressions=self._parse_term()) elif self._match_pair(TokenType.LT, TokenType.LT): @@ -3376,7 +3435,17 @@ class Parser(metaclass=_Parser): return self._parse_tokens(self._parse_factor, self.TERM) def _parse_factor(self) -> t.Optional[exp.Expression]: - return self._parse_tokens(self._parse_unary, self.FACTOR) + if self.EXPONENT: + factor = self._parse_tokens(self._parse_exponent, self.FACTOR) + else: + factor = self._parse_tokens(self._parse_unary, self.FACTOR) + if isinstance(factor, exp.Div): + factor.args["typed"] = self.dialect.TYPED_DIVISION + factor.args["safe"] = self.dialect.SAFE_DIVISION + return factor + + def _parse_exponent(self) -> t.Optional[exp.Expression]: + return self._parse_tokens(self._parse_unary, self.EXPONENT) def _parse_unary(self) -> t.Optional[exp.Expression]: if self._match_set(self.UNARY_PARSERS): @@ -3427,14 +3496,14 @@ class Parser(metaclass=_Parser): ) if identifier: - tokens = self._tokenizer.tokenize(identifier.name) + tokens = self.dialect.tokenize(identifier.name) if len(tokens) != 1: self.raise_error("Unexpected identifier", self._prev) if tokens[0].token_type in self.TYPE_TOKENS: self._prev = tokens[0] - elif self.SUPPORTS_USER_DEFINED_TYPES: + elif self.dialect.SUPPORTS_USER_DEFINED_TYPES: type_name = identifier.name while self._match(TokenType.DOT): @@ -3713,6 +3782,7 @@ class Parser(metaclass=_Parser): if not self._curr: return None + comments = self._curr.comments token_type = self._curr.token_type this = self._curr.text upper = this.upper() @@ -3754,13 +3824,22 @@ class Parser(metaclass=_Parser): args = self._parse_csv(lambda: self._parse_lambda(alias=alias)) if function and not anonymous: - func = self.validate_expression(function(args), args) - if not self.NORMALIZE_FUNCTIONS: + if "dialect" in function.__code__.co_varnames: + func = function(args, dialect=self.dialect) + else: + func = function(args) + + func = self.validate_expression(func, args) + if not self.dialect.NORMALIZE_FUNCTIONS: func.meta["name"] = this + this = func else: this = self.expression(exp.Anonymous, this=this, expressions=args) + if isinstance(this, exp.Expression): + this.add_comments(comments) + self._match_r_paren(this) return self._parse_window(this) @@ -3875,6 +3954,11 @@ class Parser(metaclass=_Parser): not_null=self._match_pair(TokenType.NOT, TokenType.NULL), ) ) + elif kind and self._match_pair(TokenType.ALIAS, TokenType.L_PAREN, advance=False): + self._match(TokenType.ALIAS) + constraints.append( + self.expression(exp.TransformColumnConstraint, this=self._parse_field()) + ) while True: constraint = self._parse_column_constraint() @@ -3917,7 +4001,11 @@ class Parser(metaclass=_Parser): def _parse_generated_as_identity( self, - ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint: + ) -> ( + exp.GeneratedAsIdentityColumnConstraint + | exp.ComputedColumnConstraint + | exp.GeneratedAsRowColumnConstraint + ): if self._match_text_seq("BY", "DEFAULT"): on_null = self._match_pair(TokenType.ON, TokenType.NULL) this = self.expression( @@ -3928,6 +4016,14 @@ class Parser(metaclass=_Parser): this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) self._match(TokenType.ALIAS) + + if self._match_text_seq("ROW"): + start = self._match_text_seq("START") + if not start: + self._match(TokenType.END) + hidden = self._match_text_seq("HIDDEN") + return self.expression(exp.GeneratedAsRowColumnConstraint, start=start, hidden=hidden) + identity = self._match_text_seq("IDENTITY") if self._match(TokenType.L_PAREN): @@ -4100,6 +4196,16 @@ class Parser(metaclass=_Parser): def _parse_primary_key_part(self) -> t.Optional[exp.Expression]: return self._parse_field() + def _parse_period_for_system_time(self) -> exp.PeriodForSystemTimeConstraint: + self._match(TokenType.TIMESTAMP_SNAPSHOT) + + id_vars = self._parse_wrapped_id_vars() + return self.expression( + exp.PeriodForSystemTimeConstraint, + this=seq_get(id_vars, 0), + expression=seq_get(id_vars, 1), + ) + def _parse_primary_key( self, wrapped_optional: bool = False, in_props: bool = False ) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey: @@ -4145,7 +4251,7 @@ class Parser(metaclass=_Parser): elif not this or this.name.upper() == "ARRAY": this = self.expression(exp.Array, expressions=expressions) else: - expressions = apply_index_offset(this, expressions, -self.INDEX_OFFSET) + expressions = apply_index_offset(this, expressions, -self.dialect.INDEX_OFFSET) this = self.expression(exp.Bracket, this=this, expressions=expressions) self._add_comments(this) @@ -4259,8 +4365,8 @@ class Parser(metaclass=_Parser): format=exp.Literal.string( format_time( fmt_string.this if fmt_string else "", - self.FORMAT_MAPPING or self.TIME_MAPPING, - self.FORMAT_TRIE or self.TIME_TRIE, + self.dialect.FORMAT_MAPPING or self.dialect.TIME_MAPPING, + self.dialect.FORMAT_TRIE or self.dialect.TIME_TRIE, ) ), ) @@ -4280,30 +4386,6 @@ class Parser(metaclass=_Parser): exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt, safe=safe ) - def _parse_concat(self) -> t.Optional[exp.Expression]: - args = self._parse_csv(self._parse_conjunction) - if self.CONCAT_NULL_OUTPUTS_STRING: - args = self._ensure_string_if_null(args) - - # Some dialects (e.g. Trino) don't allow a single-argument CONCAT call, so when - # we find such a call we replace it with its argument. - if len(args) == 1: - return args[0] - - return self.expression( - exp.Concat if self.STRICT_STRING_CONCAT else exp.SafeConcat, expressions=args - ) - - def _parse_concat_ws(self) -> t.Optional[exp.Expression]: - args = self._parse_csv(self._parse_conjunction) - if len(args) < 2: - return self.expression(exp.ConcatWs, expressions=args) - delim, *values = args - if self.CONCAT_NULL_OUTPUTS_STRING: - values = self._ensure_string_if_null(values) - - return self.expression(exp.ConcatWs, expressions=[delim] + values) - def _parse_string_agg(self) -> exp.Expression: if self._match(TokenType.DISTINCT): args: t.List[t.Optional[exp.Expression]] = [ @@ -4495,19 +4577,6 @@ class Parser(metaclass=_Parser): empty_handling=empty_handling, ) - def _parse_logarithm(self) -> exp.Func: - # Default argument order is base, expression - args = self._parse_csv(self._parse_range) - - if len(args) > 1: - if not self.LOG_BASE_FIRST: - args.reverse() - return exp.Log.from_arg_list(args) - - return self.expression( - exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0) - ) - def _parse_match_against(self) -> exp.MatchAgainst: expressions = self._parse_csv(self._parse_column) @@ -4755,6 +4824,7 @@ class Parser(metaclass=_Parser): self, this: t.Optional[exp.Expression], explicit: bool = False ) -> t.Optional[exp.Expression]: any_token = self._match(TokenType.ALIAS) + comments = self._prev_comments if explicit and not any_token: return this @@ -4762,6 +4832,7 @@ class Parser(metaclass=_Parser): if self._match(TokenType.L_PAREN): aliases = self.expression( exp.Aliases, + comments=comments, this=this, expressions=self._parse_csv(lambda: self._parse_id_var(any_token)), ) @@ -4771,7 +4842,7 @@ class Parser(metaclass=_Parser): alias = self._parse_id_var(any_token) if alias: - return self.expression(exp.Alias, this=this, alias=alias) + return self.expression(exp.Alias, comments=comments, this=this, alias=alias) return this @@ -4792,8 +4863,8 @@ class Parser(metaclass=_Parser): return None def _parse_string(self) -> t.Optional[exp.Expression]: - if self._match(TokenType.STRING): - return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev) + if self._match_set((TokenType.STRING, TokenType.RAW_STRING)): + return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev) return self._parse_placeholder() def _parse_string_as_identifier(self) -> t.Optional[exp.Identifier]: @@ -4821,7 +4892,7 @@ class Parser(metaclass=_Parser): return self._parse_placeholder() def _advance_any(self) -> t.Optional[Token]: - if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: + if self._curr and self._curr.token_type not in self.RESERVED_TOKENS: self._advance() return self._prev return None @@ -4951,7 +5022,7 @@ class Parser(metaclass=_Parser): if self._match_texts(self.TRANSACTION_KIND): this = self._prev.text - self._match_texts({"TRANSACTION", "WORK"}) + self._match_texts(("TRANSACTION", "WORK")) modes = [] while True: @@ -4971,7 +5042,7 @@ class Parser(metaclass=_Parser): savepoint = None is_rollback = self._prev.token_type == TokenType.ROLLBACK - self._match_texts({"TRANSACTION", "WORK"}) + self._match_texts(("TRANSACTION", "WORK")) if self._match_text_seq("TO"): self._match_text_seq("SAVEPOINT") @@ -4986,6 +5057,10 @@ class Parser(metaclass=_Parser): return self.expression(exp.Commit, chain=chain) + def _parse_refresh(self) -> exp.Refresh: + self._match(TokenType.TABLE) + return self.expression(exp.Refresh, this=self._parse_string() or self._parse_table()) + def _parse_add_column(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("ADD"): return None @@ -5050,10 +5125,9 @@ class Parser(metaclass=_Parser): return self._parse_csv(self._parse_add_constraint) self._retreat(index) - if not self.ALTER_TABLE_ADD_COLUMN_KEYWORD and self._match_text_seq("ADD"): - return self._parse_csv(self._parse_field_def) - - return self._parse_csv(self._parse_add_column) + if not self.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN and self._match_text_seq("ADD"): + return self._parse_wrapped_csv(self._parse_field_def, optional=True) + return self._parse_wrapped_csv(self._parse_add_column, optional=True) def _parse_alter_table_alter(self) -> exp.AlterColumn: self._match(TokenType.COLUMN) @@ -5198,7 +5272,7 @@ class Parser(metaclass=_Parser): ) -> t.Optional[exp.Expression]: index = self._index - if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"): + if kind in ("GLOBAL", "SESSION") and self._match_text_seq("TRANSACTION"): return self._parse_set_transaction(global_=kind == "GLOBAL") left = self._parse_primary() or self._parse_id_var() @@ -5292,7 +5366,9 @@ class Parser(metaclass=_Parser): self._match_r_paren() return self.expression(exp.DictRange, this=this, min=min, max=max) - def _parse_comprehension(self, this: exp.Expression) -> t.Optional[exp.Comprehension]: + def _parse_comprehension( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Comprehension]: index = self._index expression = self._parse_column() if not self._match(TokenType.IN): @@ -5441,10 +5517,3 @@ class Parser(metaclass=_Parser): else: column.replace(dot_or_id) return node - - def _ensure_string_if_null(self, values: t.List[exp.Expression]) -> t.List[exp.Expression]: - return [ - exp.func("COALESCE", exp.cast(value, "text"), exp.Literal.string("")) - for value in values - if value - ] -- cgit v1.2.3