From 3742f86d166160ca3843872ebecb6f30c51f6085 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 14 Aug 2023 12:12:19 +0200 Subject: Merging upstream version 17.12.0. Signed-off-by: Daniel Baumann --- sqlglot/parser.py | 118 +++++++++++++++++++++++++++++++++--------------------- 1 file changed, 73 insertions(+), 45 deletions(-) (limited to 'sqlglot/parser.py') diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 35a1744..3db4453 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -102,15 +102,23 @@ class Parser(metaclass=_Parser): TokenType.CURRENT_USER: exp.CurrentUser, } + STRUCT_TYPE_TOKENS = { + TokenType.NESTED, + TokenType.STRUCT, + } + NESTED_TYPE_TOKENS = { TokenType.ARRAY, + TokenType.LOWCARDINALITY, TokenType.MAP, TokenType.NULLABLE, - TokenType.STRUCT, + *STRUCT_TYPE_TOKENS, } ENUM_TYPE_TOKENS = { TokenType.ENUM, + TokenType.ENUM8, + TokenType.ENUM16, } TYPE_TOKENS = { @@ -128,6 +136,7 @@ class Parser(metaclass=_Parser): TokenType.UINT128, TokenType.INT256, TokenType.UINT256, + TokenType.FIXEDSTRING, TokenType.FLOAT, TokenType.DOUBLE, TokenType.CHAR, @@ -145,6 +154,7 @@ class Parser(metaclass=_Parser): TokenType.JSONB, TokenType.INTERVAL, TokenType.TIME, + TokenType.TIMETZ, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPLTZ, @@ -187,7 +197,7 @@ class Parser(metaclass=_Parser): TokenType.INET, TokenType.IPADDRESS, TokenType.IPPREFIX, - TokenType.ENUM, + *ENUM_TYPE_TOKENS, *NESTED_TYPE_TOKENS, } @@ -384,11 +394,16 @@ class Parser(metaclass=_Parser): TokenType.STAR: exp.Mul, } - TIMESTAMPS = { + TIMES = { TokenType.TIME, + TokenType.TIMETZ, + } + + TIMESTAMPS = { TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPLTZ, + *TIMES, } SET_OPERATIONS = { @@ -1165,6 +1180,8 @@ class Parser(metaclass=_Parser): def _parse_create(self) -> exp.Create | exp.Command: # Note: this can't be None because we've matched a statement parser start = self._prev + comments = self._prev_comments + replace = start.text.upper() == "REPLACE" or self._match_pair( TokenType.OR, TokenType.REPLACE ) @@ -1273,6 +1290,7 @@ class Parser(metaclass=_Parser): return self.expression( exp.Create, + comments=comments, this=this, kind=create_token.text, replace=replace, @@ -2338,7 +2356,8 @@ class Parser(metaclass=_Parser): kwargs["this"].set("joins", joins) - return self.expression(exp.Join, **kwargs) + comments = [c for token in (method, side, kind) if token for c in token.comments] + return self.expression(exp.Join, comments=comments, **kwargs) def _parse_index( self, @@ -2619,11 +2638,18 @@ class Parser(metaclass=_Parser): def _parse_pivot(self) -> t.Optional[exp.Pivot]: index = self._index + include_nulls = None if self._match(TokenType.PIVOT): unpivot = False elif self._match(TokenType.UNPIVOT): unpivot = True + + # https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-unpivot.html#syntax + if self._match_text_seq("INCLUDE", "NULLS"): + include_nulls = True + elif self._match_text_seq("EXCLUDE", "NULLS"): + include_nulls = False else: return None @@ -2654,7 +2680,13 @@ class Parser(metaclass=_Parser): self._match_r_paren() - pivot = self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot) + pivot = self.expression( + exp.Pivot, + expressions=expressions, + field=field, + unpivot=unpivot, + include_nulls=include_nulls, + ) if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False): pivot.set("alias", self._parse_table_alias()) @@ -3096,7 +3128,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.PseudoType, this=self._prev.text) nested = type_token in self.NESTED_TYPE_TOKENS - is_struct = type_token == TokenType.STRUCT + is_struct = type_token in self.STRUCT_TYPE_TOKENS expressions = None maybe_func = False @@ -3108,7 +3140,7 @@ class Parser(metaclass=_Parser): lambda: self._parse_types(check_func=check_func, schema=schema) ) elif type_token in self.ENUM_TYPE_TOKENS: - expressions = self._parse_csv(self._parse_primary) + expressions = self._parse_csv(self._parse_equality) else: expressions = self._parse_csv(self._parse_type_size) @@ -3118,29 +3150,9 @@ class Parser(metaclass=_Parser): maybe_func = True - if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): - this = exp.DataType( - this=exp.DataType.Type.ARRAY, - expressions=[ - exp.DataType( - this=exp.DataType.Type[type_token.value], - expressions=expressions, - nested=nested, - ) - ], - nested=True, - ) - - while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): - this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True) - - return this - - if self._match(TokenType.L_BRACKET): - self._retreat(index) - return None - + this: t.Optional[exp.Expression] = None values: t.Optional[t.List[t.Optional[exp.Expression]]] = None + if nested and self._match(TokenType.LT): if is_struct: expressions = self._parse_csv(self._parse_struct_types) @@ -3156,23 +3168,35 @@ class Parser(metaclass=_Parser): values = self._parse_csv(self._parse_conjunction) self._match_set((TokenType.R_BRACKET, TokenType.R_PAREN)) - value: t.Optional[exp.Expression] = None if type_token in self.TIMESTAMPS: if self._match_text_seq("WITH", "TIME", "ZONE"): maybe_func = False - value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions) + tz_type = ( + exp.DataType.Type.TIMETZ + if type_token in self.TIMES + else exp.DataType.Type.TIMESTAMPTZ + ) + this = exp.DataType(this=tz_type, expressions=expressions) elif self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE"): maybe_func = False - value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions) + this = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions) elif self._match_text_seq("WITHOUT", "TIME", "ZONE"): maybe_func = False elif type_token == TokenType.INTERVAL: - unit = self._parse_var() + if self._match_text_seq("YEAR", "TO", "MONTH"): + span: t.Optional[t.List[exp.Expression]] = [exp.IntervalYearToMonthSpan()] + elif self._match_text_seq("DAY", "TO", "SECOND"): + span = [exp.IntervalDayToSecondSpan()] + else: + span = None + unit = not span and self._parse_var() if not unit: - value = self.expression(exp.DataType, this=exp.DataType.Type.INTERVAL) + this = self.expression( + exp.DataType, this=exp.DataType.Type.INTERVAL, expressions=span + ) else: - value = self.expression(exp.Interval, unit=unit) + this = self.expression(exp.Interval, unit=unit) if maybe_func and check_func: index2 = self._index @@ -3184,16 +3208,19 @@ class Parser(metaclass=_Parser): self._retreat(index2) - if value: - return value + if not this: + this = exp.DataType( + this=exp.DataType.Type[type_token.value], + expressions=expressions, + nested=nested, + values=values, + prefix=prefix, + ) - return exp.DataType( - this=exp.DataType.Type[type_token.value], - expressions=expressions, - nested=nested, - values=values, - prefix=prefix, - ) + while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): + this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True) + + return this def _parse_struct_types(self) -> t.Optional[exp.Expression]: this = self._parse_type() or self._parse_id_var() @@ -3738,6 +3765,7 @@ class Parser(metaclass=_Parser): ifs = [] default = None + comments = self._prev_comments expression = self._parse_conjunction() while self._match(TokenType.WHEN): @@ -3753,7 +3781,7 @@ class Parser(metaclass=_Parser): self.raise_error("Expected END after CASE", self._prev) return self._parse_window( - self.expression(exp.Case, this=expression, ifs=ifs, default=default) + self.expression(exp.Case, comments=comments, this=expression, ifs=ifs, default=default) ) def _parse_if(self) -> t.Optional[exp.Expression]: -- cgit v1.2.3