From 90150543f9314be683d22a16339effd774192f6d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 22 Sep 2022 06:31:28 +0200 Subject: Merging upstream version 6.1.1. Signed-off-by: Daniel Baumann --- sqlglot/parser.py | 404 +++++++++++++++++++++++++++--------------------------- 1 file changed, 202 insertions(+), 202 deletions(-) (limited to 'sqlglot/parser.py') diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 9396c50..f46bafe 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -78,6 +78,7 @@ class Parser: TokenType.TEXT, TokenType.BINARY, TokenType.JSON, + TokenType.INTERVAL, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, TokenType.DATETIME, @@ -85,6 +86,12 @@ class Parser: TokenType.DECIMAL, TokenType.UUID, TokenType.GEOGRAPHY, + TokenType.GEOMETRY, + TokenType.HLLSKETCH, + TokenType.SUPER, + TokenType.SERIAL, + TokenType.SMALLSERIAL, + TokenType.BIGSERIAL, *NESTED_TYPE_TOKENS, } @@ -100,13 +107,14 @@ class Parser: ID_VAR_TOKENS = { TokenType.VAR, TokenType.ALTER, + TokenType.ALWAYS, TokenType.BEGIN, + TokenType.BOTH, TokenType.BUCKET, TokenType.CACHE, TokenType.COLLATE, TokenType.COMMIT, TokenType.CONSTRAINT, - TokenType.CONVERT, TokenType.DEFAULT, TokenType.DELETE, TokenType.ENGINE, @@ -115,14 +123,19 @@ class Parser: TokenType.FALSE, TokenType.FIRST, TokenType.FOLLOWING, + TokenType.FOR, TokenType.FORMAT, TokenType.FUNCTION, + TokenType.GENERATED, + TokenType.IDENTITY, TokenType.IF, TokenType.INDEX, TokenType.ISNULL, TokenType.INTERVAL, TokenType.LAZY, + TokenType.LEADING, TokenType.LOCATION, + TokenType.NATURAL, TokenType.NEXT, TokenType.ONLY, TokenType.OPTIMIZE, @@ -141,6 +154,7 @@ class Parser: TokenType.TABLE_FORMAT, TokenType.TEMPORARY, TokenType.TOP, + TokenType.TRAILING, TokenType.TRUNCATE, TokenType.TRUE, TokenType.UNBOUNDED, @@ -150,18 +164,15 @@ class Parser: *TYPE_TOKENS, } - CASTS = { - TokenType.CAST, - TokenType.TRY_CAST, - } + TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL} + + TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH} FUNC_TOKENS = { - TokenType.CONVERT, TokenType.CURRENT_DATE, TokenType.CURRENT_DATETIME, TokenType.CURRENT_TIMESTAMP, TokenType.CURRENT_TIME, - TokenType.EXTRACT, TokenType.FILTER, TokenType.FIRST, TokenType.FORMAT, @@ -178,7 +189,6 @@ class Parser: TokenType.DATETIME, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, - *CASTS, *NESTED_TYPE_TOKENS, *SUBQUERY_PREDICATES, } @@ -215,6 +225,7 @@ class Parser: FACTOR = { TokenType.DIV: exp.IntDiv, + TokenType.LR_ARROW: exp.Distance, TokenType.SLASH: exp.Div, TokenType.STAR: exp.Mul, } @@ -299,14 +310,13 @@ class Parser: PRIMARY_PARSERS = { TokenType.STRING: lambda _, token: exp.Literal.string(token.text), TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text), - TokenType.STAR: lambda self, _: exp.Star( - **{"except": self._parse_except(), "replace": self._parse_replace()} - ), + TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}), TokenType.NULL: lambda *_: exp.Null(), TokenType.TRUE: lambda *_: exp.Boolean(this=True), TokenType.FALSE: lambda *_: exp.Boolean(this=False), TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(), TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text), + TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text), TokenType.INTRODUCER: lambda self, token: self.expression( exp.Introducer, this=token.text, @@ -319,13 +329,16 @@ class Parser: TokenType.IN: lambda self, this: self._parse_in(this), TokenType.IS: lambda self, this: self._parse_is(this), TokenType.LIKE: lambda self, this: self._parse_escape( - self.expression(exp.Like, this=this, expression=self._parse_type()) + self.expression(exp.Like, this=this, expression=self._parse_bitwise()) ), TokenType.ILIKE: lambda self, this: self._parse_escape( - self.expression(exp.ILike, this=this, expression=self._parse_type()) + self.expression(exp.ILike, this=this, expression=self._parse_bitwise()) ), TokenType.RLIKE: lambda self, this: self.expression( - exp.RegexpLike, this=this, expression=self._parse_type() + exp.RegexpLike, this=this, expression=self._parse_bitwise() + ), + TokenType.SIMILAR_TO: lambda self, this: self.expression( + exp.SimilarTo, this=this, expression=self._parse_bitwise() ), } @@ -363,28 +376,21 @@ class Parser: } FUNCTION_PARSERS = { - TokenType.CONVERT: lambda self, _: self._parse_convert(), - TokenType.EXTRACT: lambda self, _: self._parse_extract(), - **{ - token_type: lambda self, token_type: self._parse_cast( - self.STRICT_CAST and token_type == TokenType.CAST - ) - for token_type in CASTS - }, + "CONVERT": lambda self: self._parse_convert(), + "EXTRACT": lambda self: self._parse_extract(), + "SUBSTRING": lambda self: self._parse_substring(), + "TRIM": lambda self: self._parse_trim(), + "CAST": lambda self: self._parse_cast(self.STRICT_CAST), + "TRY_CAST": lambda self: self._parse_cast(False), } QUERY_MODIFIER_PARSERS = { - "laterals": lambda self: self._parse_laterals(), - "joins": lambda self: self._parse_joins(), "where": lambda self: self._parse_where(), "group": lambda self: self._parse_group(), "having": lambda self: self._parse_having(), "qualify": lambda self: self._parse_qualify(), - "window": lambda self: self._match(TokenType.WINDOW) - and self._parse_window(self._parse_id_var(), alias=True), - "distribute": lambda self: self._parse_sort( - TokenType.DISTRIBUTE_BY, exp.Distribute - ), + "window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True), + "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute), "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), "order": lambda self: self._parse_order(), @@ -392,6 +398,8 @@ class Parser: "offset": lambda self: self._parse_offset(), } + MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) + CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX} STRICT_CAST = True @@ -457,9 +465,7 @@ class Parser: Returns the list of syntax trees (:class:`~sqlglot.expressions.Expression`). """ - return self._parse( - parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql - ) + return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql) def parse_into(self, expression_types, raw_tokens, sql=None): for expression_type in ensure_list(expression_types): @@ -532,21 +538,13 @@ class Parser: for k in expression.args: if k not in expression.arg_types: - self.raise_error( - f"Unexpected keyword: '{k}' for {expression.__class__}" - ) + self.raise_error(f"Unexpected keyword: '{k}' for {expression.__class__}") for k, mandatory in expression.arg_types.items(): v = expression.args.get(k) if mandatory and (v is None or (isinstance(v, list) and not v)): - self.raise_error( - f"Required keyword: '{k}' missing for {expression.__class__}" - ) + self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}") - if ( - args - and len(args) > len(expression.arg_types) - and not expression.is_var_len_args - ): + if args and len(args) > len(expression.arg_types) and not expression.is_var_len_args: self.raise_error( f"The number of provided arguments ({len(args)}) is greater than " f"the maximum number of supported arguments ({len(expression.arg_types)})" @@ -594,11 +592,7 @@ class Parser: ) expression = self._parse_expression() - expression = ( - self._parse_set_operations(expression) - if expression - else self._parse_select() - ) + expression = self._parse_set_operations(expression) if expression else self._parse_select() self._parse_query_modifiers(expression) return expression @@ -618,11 +612,7 @@ class Parser: ) def _parse_exists(self, not_=False): - return ( - self._match(TokenType.IF) - and (not not_ or self._match(TokenType.NOT)) - and self._match(TokenType.EXISTS) - ) + return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS) def _parse_create(self): replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE) @@ -647,11 +637,9 @@ class Parser: this = self._parse_index() elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW): this = self._parse_table(schema=True) - properties = self._parse_properties( - this if isinstance(this, exp.Schema) else None - ) + properties = self._parse_properties(this if isinstance(this, exp.Schema) else None) if self._match(TokenType.ALIAS): - expression = self._parse_select() + expression = self._parse_select(nested=True) return self.expression( exp.Create, @@ -682,9 +670,7 @@ class Parser: if schema and not isinstance(value, exp.Schema): columns = {v.name.upper() for v in value.expressions} partitions = [ - expression - for expression in schema.expressions - if expression.this.name.upper() in columns + expression for expression in schema.expressions if expression.this.name.upper() in columns ] schema.set( "expressions", @@ -811,7 +797,7 @@ class Parser: this=self._parse_table(schema=True), exists=self._parse_exists(), partition=self._parse_partition(), - expression=self._parse_select(), + expression=self._parse_select(nested=True), overwrite=overwrite, ) @@ -829,8 +815,7 @@ class Parser: exp.Update, **{ "this": self._parse_table(schema=True), - "expressions": self._match(TokenType.SET) - and self._parse_csv(self._parse_equality), + "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), "from": self._parse_from(), "where": self._parse_where(), }, @@ -865,7 +850,7 @@ class Parser: this=table, lazy=lazy, options=options, - expression=self._parse_select(), + expression=self._parse_select(nested=True), ) def _parse_partition(self): @@ -894,9 +879,7 @@ class Parser: self._match_r_paren() return self.expression(exp.Tuple, expressions=expressions) - def _parse_select(self, table=None): - index = self._index - + def _parse_select(self, nested=False, table=False): if self._match(TokenType.SELECT): hint = self._parse_hint() all_ = self._match(TokenType.ALL) @@ -912,9 +895,7 @@ class Parser: self.raise_error("Cannot specify both ALL and DISTINCT after SELECT") limit = self._parse_limit(top=True) - expressions = self._parse_csv( - lambda: self._parse_annotation(self._parse_expression()) - ) + expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression())) this = self.expression( exp.Select, @@ -960,19 +941,13 @@ class Parser: ) else: self.raise_error(f"{this.key} does not support CTE") - elif self._match(TokenType.L_PAREN): - this = self._parse_table() if table else self._parse_select() - - if this: - self._parse_query_modifiers(this) - self._match_r_paren() - this = self._parse_subquery(this) - else: - self._retreat(index) + elif (table or nested) and self._match(TokenType.L_PAREN): + this = self._parse_table() if table else self._parse_select(nested=True) + self._parse_query_modifiers(this) + self._match_r_paren() + this = self._parse_subquery(this) elif self._match(TokenType.VALUES): - this = self.expression( - exp.Values, expressions=self._parse_csv(self._parse_value) - ) + this = self.expression(exp.Values, expressions=self._parse_csv(self._parse_value)) alias = self._parse_table_alias() if alias: this = self.expression(exp.Subquery, this=this, alias=alias) @@ -1001,7 +976,7 @@ class Parser: def _parse_table_alias(self): any_token = self._match(TokenType.ALIAS) - alias = self._parse_id_var(any_token) + alias = self._parse_id_var(any_token=any_token, tokens=self.TABLE_ALIAS_TOKENS) columns = None if self._match(TokenType.L_PAREN): @@ -1021,9 +996,24 @@ class Parser: return self.expression(exp.Subquery, this=this, alias=self._parse_table_alias()) def _parse_query_modifiers(self, this): - if not isinstance(this, (exp.Subquery, exp.Subqueryable)): + if not isinstance(this, self.MODIFIABLES): return + table = isinstance(this, exp.Table) + + while True: + lateral = self._parse_lateral() + join = self._parse_join() + comma = None if table else self._match(TokenType.COMMA) + if lateral: + this.append("laterals", lateral) + if join: + this.append("joins", join) + if comma: + this.args["from"].append("expressions", self._parse_table()) + if not (lateral or join or comma): + break + for key, parser in self.QUERY_MODIFIER_PARSERS.items(): expression = parser(self) @@ -1032,9 +1022,7 @@ class Parser: def _parse_annotation(self, expression): if self._match(TokenType.ANNOTATION): - return self.expression( - exp.Annotation, this=self._prev.text, expression=expression - ) + return self.expression(exp.Annotation, this=self._prev.text, expression=expression) return expression @@ -1052,16 +1040,16 @@ class Parser: return self.expression(exp.From, expressions=self._parse_csv(self._parse_table)) - def _parse_laterals(self): - return self._parse_all(self._parse_lateral) - def _parse_lateral(self): if not self._match(TokenType.LATERAL): return None - if not self._match(TokenType.VIEW): - self.raise_error("Expected VIEW after LATERAL") + subquery = self._parse_select(table=True) + if subquery: + return self.expression(exp.Lateral, this=subquery) + + self._match(TokenType.VIEW) outer = self._match(TokenType.OUTER) return self.expression( @@ -1071,31 +1059,27 @@ class Parser: alias=self.expression( exp.TableAlias, this=self._parse_id_var(any_token=False), - columns=( - self._parse_csv(self._parse_id_var) - if self._match(TokenType.ALIAS) - else None - ), + columns=(self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else None), ), ) - def _parse_joins(self): - return self._parse_all(self._parse_join) - def _parse_join_side_and_kind(self): return ( + self._match(TokenType.NATURAL) and self._prev, self._match_set(self.JOIN_SIDES) and self._prev, self._match_set(self.JOIN_KINDS) and self._prev, ) def _parse_join(self): - side, kind = self._parse_join_side_and_kind() + natural, side, kind = self._parse_join_side_and_kind() if not self._match(TokenType.JOIN): return None kwargs = {"this": self._parse_table()} + if natural: + kwargs["natural"] = True if side: kwargs["side"] = side.text if kind: @@ -1120,6 +1104,11 @@ class Parser: ) def _parse_table(self, schema=False): + lateral = self._parse_lateral() + + if lateral: + return lateral + unnest = self._parse_unnest() if unnest: @@ -1172,9 +1161,7 @@ class Parser: expressions = self._parse_csv(self._parse_column) self._match_r_paren() - ordinality = bool( - self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY) - ) + ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY)) alias = self._parse_table_alias() @@ -1280,17 +1267,13 @@ class Parser: if not self._match(TokenType.ORDER_BY): return this - return self.expression( - exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered) - ) + return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)) def _parse_sort(self, token_type, exp_class): if not self._match(token_type): return None - return self.expression( - exp_class, expressions=self._parse_csv(self._parse_ordered) - ) + return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) def _parse_ordered(self): this = self._parse_conjunction() @@ -1305,22 +1288,17 @@ class Parser: if ( not explicitly_null_ordered and ( - (asc and self.null_ordering == "nulls_are_small") - or (desc and self.null_ordering != "nulls_are_small") + (asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small") ) and self.null_ordering != "nulls_are_last" ): nulls_first = True - return self.expression( - exp.Ordered, this=this, desc=desc, nulls_first=nulls_first - ) + return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first) def _parse_limit(self, this=None, top=False): if self._match(TokenType.TOP if top else TokenType.LIMIT): - return self.expression( - exp.Limit, this=this, expression=self._parse_number() - ) + return self.expression(exp.Limit, this=this, expression=self._parse_number()) if self._match(TokenType.FETCH): direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) direction = self._prev.text if direction else "FIRST" @@ -1354,7 +1332,7 @@ class Parser: expression, this=this, distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL), - expression=self._parse_select(), + expression=self._parse_select(nested=True), ) def _parse_expression(self): @@ -1396,9 +1374,7 @@ class Parser: this = self.expression(exp.In, this=this, unnest=unnest) else: self._match_l_paren() - expressions = self._parse_csv( - lambda: self._parse_select() or self._parse_expression() - ) + expressions = self._parse_csv(lambda: self._parse_select() or self._parse_expression()) if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): this = self.expression(exp.In, this=this, query=expressions[0]) @@ -1430,13 +1406,9 @@ class Parser: expression=self._parse_term(), ) elif self._match_pair(TokenType.LT, TokenType.LT): - this = self.expression( - exp.BitwiseLeftShift, this=this, expression=self._parse_term() - ) + this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term()) elif self._match_pair(TokenType.GT, TokenType.GT): - this = self.expression( - exp.BitwiseRightShift, this=this, expression=self._parse_term() - ) + this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term()) else: break @@ -1524,7 +1496,7 @@ class Parser: self.raise_error("Expecting >") if type_token in self.TIMESTAMPS: - tz = self._match(TokenType.WITH_TIME_ZONE) + tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ self._match(TokenType.WITHOUT_TIME_ZONE) if tz: return exp.DataType( @@ -1594,16 +1566,14 @@ class Parser: if query: expressions = [query] else: - expressions = self._parse_csv( - lambda: self._parse_alias(self._parse_conjunction(), explicit=True) - ) + expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True)) this = list_get(expressions, 0) self._parse_query_modifiers(this) self._match_r_paren() if isinstance(this, exp.Subqueryable): - return self._parse_subquery(this) + return self._parse_set_operations(self._parse_subquery(this)) if len(expressions) > 1: return self.expression(exp.Tuple, expressions=expressions) return self.expression(exp.Paren, this=this) @@ -1611,11 +1581,7 @@ class Parser: return None def _parse_field(self, any_token=False): - return ( - self._parse_primary() - or self._parse_function() - or self._parse_id_var(any_token) - ) + return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token) def _parse_function(self): if not self._curr: @@ -1628,21 +1594,22 @@ class Parser: if not self._next or self._next.token_type != TokenType.L_PAREN: if token_type in self.NO_PAREN_FUNCTIONS: - return self.expression( - self._advance() or self.NO_PAREN_FUNCTIONS[token_type] - ) + return self.expression(self._advance() or self.NO_PAREN_FUNCTIONS[token_type]) return None if token_type not in self.FUNC_TOKENS: return None - if self._match_set(self.FUNCTION_PARSERS): - self._advance() - this = self.FUNCTION_PARSERS[token_type](self, token_type) + this = self._curr.text + upper = this.upper() + self._advance(2) + + parser = self.FUNCTION_PARSERS.get(upper) + + if parser: + this = parser(self) else: subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type) - this = self._curr.text - self._advance(2) if subquery_predicate and self._curr.token_type in ( TokenType.SELECT, @@ -1652,7 +1619,7 @@ class Parser: self._match_r_paren() return this - function = self.FUNCTIONS.get(this.upper()) + function = self.FUNCTIONS.get(upper) args = self._parse_csv(self._parse_lambda) if function: @@ -1700,10 +1667,7 @@ class Parser: self._retreat(index) return this - args = self._parse_csv( - lambda: self._parse_constraint() - or self._parse_column_def(self._parse_field()) - ) + args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))) self._match_r_paren() return self.expression(exp.Schema, this=this, expressions=args) @@ -1720,12 +1684,9 @@ class Parser: break constraints.append(constraint) - return self.expression( - exp.ColumnDef, this=this, kind=kind, constraints=constraints - ) + return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints) def _parse_column_constraint(self): - kind = None this = None if self._match(TokenType.CONSTRAINT): @@ -1735,28 +1696,28 @@ class Parser: kind = exp.AutoIncrementColumnConstraint() elif self._match(TokenType.CHECK): self._match_l_paren() - kind = self.expression( - exp.CheckColumnConstraint, this=self._parse_conjunction() - ) + kind = self.expression(exp.CheckColumnConstraint, this=self._parse_conjunction()) self._match_r_paren() elif self._match(TokenType.COLLATE): kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var()) elif self._match(TokenType.DEFAULT): - kind = self.expression( - exp.DefaultColumnConstraint, this=self._parse_field() - ) - elif self._match(TokenType.NOT) and self._match(TokenType.NULL): + kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_field()) + elif self._match_pair(TokenType.NOT, TokenType.NULL): kind = exp.NotNullColumnConstraint() elif self._match(TokenType.SCHEMA_COMMENT): - kind = self.expression( - exp.CommentColumnConstraint, this=self._parse_string() - ) + kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string()) elif self._match(TokenType.PRIMARY_KEY): kind = exp.PrimaryKeyColumnConstraint() elif self._match(TokenType.UNIQUE): kind = exp.UniqueColumnConstraint() - - if kind is None: + elif self._match(TokenType.GENERATED): + if self._match(TokenType.BY_DEFAULT): + kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False) + else: + self._match(TokenType.ALWAYS) + kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) + self._match_pair(TokenType.ALIAS, TokenType.IDENTITY) + else: return None return self.expression(exp.ColumnConstraint, this=this, kind=kind) @@ -1864,9 +1825,7 @@ class Parser: if not self._match(TokenType.END): self.raise_error("Expected END after CASE", self._prev) - return self._parse_window( - self.expression(exp.Case, this=expression, ifs=ifs, default=default) - ) + return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default)) def _parse_if(self): if self._match(TokenType.L_PAREN): @@ -1889,7 +1848,7 @@ class Parser: if not self._match(TokenType.FROM): self.raise_error("Expected FROM after EXTRACT", self._prev) - return self.expression(exp.Extract, this=this, expression=self._parse_type()) + return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) def _parse_cast(self, strict): this = self._parse_conjunction() @@ -1917,12 +1876,54 @@ class Parser: to = None return self.expression(exp.Cast, this=this, to=to) + def _parse_substring(self): + # Postgres supports the form: substring(string [from int] [for int]) + # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 + + args = self._parse_csv(self._parse_bitwise) + + if self._match(TokenType.FROM): + args.append(self._parse_bitwise()) + if self._match(TokenType.FOR): + args.append(self._parse_bitwise()) + + this = exp.Substring.from_arg_list(args) + self.validate_expression(this, args) + + return this + + def _parse_trim(self): + # https://www.w3resource.com/sql/character-functions/trim.php + # https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html + + position = None + collation = None + + if self._match_set(self.TRIM_TYPES): + position = self._prev.text.upper() + + expression = self._parse_term() + if self._match(TokenType.FROM): + this = self._parse_term() + else: + this = expression + expression = None + + if self._match(TokenType.COLLATE): + collation = self._parse_term() + + return self.expression( + exp.Trim, + this=this, + position=position, + expression=expression, + collation=collation, + ) + def _parse_window(self, this, alias=False): if self._match(TokenType.FILTER): self._match_l_paren() - this = self.expression( - exp.Filter, this=this, expression=self._parse_where() - ) + this = self.expression(exp.Filter, this=this, expression=self._parse_where()) self._match_r_paren() if self._match(TokenType.WITHIN_GROUP): @@ -1935,6 +1936,25 @@ class Parser: self._match_r_paren() return this + # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER + # Some dialects choose to implement and some do not. + # https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html + + # There is some code above in _parse_lambda that handles + # SELECT FIRST_VALUE(TABLE.COLUMN IGNORE|RESPECT NULLS) OVER ... + + # The below changes handle + # SELECT FIRST_VALUE(TABLE.COLUMN) IGNORE|RESPECT NULLS OVER ... + + # Oracle allows both formats + # (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html) + # and Snowflake chose to do the same for familiarity + # https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes + if self._match(TokenType.IGNORE_NULLS): + this = self.expression(exp.IgnoreNulls, this=this) + elif self._match(TokenType.RESPECT_NULLS): + this = self.expression(exp.RespectNulls, this=this) + # bigquery select from window x AS (partition by ...) if alias: self._match(TokenType.ALIAS) @@ -1992,13 +2012,9 @@ class Parser: self._match(TokenType.BETWEEN) return { - "value": ( - self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) - and self._prev.text - ) + "value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text) or self._parse_bitwise(), - "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) - and self._prev.text, + "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text, } def _parse_alias(self, this, explicit=False): @@ -2023,22 +2039,16 @@ class Parser: return this - def _parse_id_var(self, any_token=True): + def _parse_id_var(self, any_token=True, tokens=None): identifier = self._parse_identifier() if identifier: return identifier - if ( - any_token - and self._curr - and self._curr.token_type not in self.RESERVED_KEYWORDS - ): + if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: return self._advance() or exp.Identifier(this=self._prev.text, quoted=False) - return self._match_set(self.ID_VAR_TOKENS) and exp.Identifier( - this=self._prev.text, quoted=False - ) + return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False) def _parse_string(self): if self._match(TokenType.STRING): @@ -2077,9 +2087,7 @@ class Parser: def _parse_star(self): if self._match(TokenType.STAR): - return exp.Star( - **{"except": self._parse_except(), "replace": self._parse_replace()} - ) + return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}) return None def _parse_placeholder(self): @@ -2117,15 +2125,10 @@ class Parser: this = parse() while self._match_set(expressions): - this = self.expression( - expressions[self._prev.token_type], this=this, expression=parse() - ) + this = self.expression(expressions[self._prev.token_type], this=this, expression=parse()) return this - def _parse_all(self, parse): - return list(iter(parse, None)) - def _parse_wrapped_id_vars(self): self._match_l_paren() expressions = self._parse_csv(self._parse_id_var) @@ -2156,10 +2159,7 @@ class Parser: if not self._curr or not self._next: return None - if ( - self._curr.token_type == token_type_a - and self._next.token_type == token_type_b - ): + if self._curr.token_type == token_type_a and self._next.token_type == token_type_b: if advance: self._advance(2) return True -- cgit v1.2.3