diff options
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- | sqlglot/parser.py | 309 |
1 file changed, 231 insertions, 78 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 67ffd8f..c2cb3a1 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -41,11 +41,17 @@ def build_like(args: t.List) -> exp.Escape | exp.Like: def binary_range_parser( - expr_type: t.Type[exp.Expression], + expr_type: t.Type[exp.Expression], reverse_args: bool = False ) -> t.Callable[[Parser, t.Optional[exp.Expression]], t.Optional[exp.Expression]]: - return lambda self, this: self._parse_escape( - self.expression(expr_type, this=this, expression=self._parse_bitwise()) - ) + def _parse_binary_range( + self: Parser, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + expression = self._parse_bitwise() + if reverse_args: + this, expression = expression, this + return self._parse_escape(self.expression(expr_type, this=this, expression=expression)) + + return _parse_binary_range def build_logarithm(args: t.List, dialect: Dialect) -> exp.Func: @@ -335,6 +341,8 @@ class Parser(metaclass=_Parser): TokenType.TABLE, TokenType.TAG, TokenType.VIEW, + TokenType.WAREHOUSE, + TokenType.STREAMLIT, } CREATABLES = { @@ -418,6 +426,7 @@ class Parser(metaclass=_Parser): TokenType.TRUE, TokenType.TRUNCATE, TokenType.UNIQUE, + TokenType.UNNEST, TokenType.UNPIVOT, TokenType.UPDATE, TokenType.USE, @@ -580,7 +589,7 @@ class Parser(metaclass=_Parser): exp.Lambda, this=self._replace_lambda( self._parse_conjunction(), - {node.name for node in expressions}, + expressions, ), expressions=expressions, ), @@ -1125,6 +1134,8 @@ class Parser(metaclass=_Parser): SELECT_START_TOKENS = {TokenType.L_PAREN, TokenType.WITH, TokenType.SELECT} + COPY_INTO_VARLEN_OPTIONS = {"FILE_FORMAT", "COPY_OPTIONS", "FORMAT_OPTIONS", "CREDENTIAL"} + STRICT_CAST = True PREFIXED_PIVOT_COLUMNS = False @@ -1160,6 +1171,9 @@ class Parser(metaclass=_Parser): # Whether the -> and ->> operators expect documents of type JSON (e.g. 
Postgres) JSON_ARROWS_REQUIRE_JSON_TYPE = False + # Whether the `:` operator is used to extract a value from a JSON document + COLON_IS_JSON_EXTRACT = False + # Whether or not a VALUES keyword needs to be followed by '(' to form a VALUES clause. # If this is True and '(' is not found, the keyword will be treated as an identifier VALUES_FOLLOWED_BY_PAREN = True @@ -1631,6 +1645,7 @@ class Parser(metaclass=_Parser): extend_props(self._parse_properties()) expression = self._match(TokenType.ALIAS) and self._parse_heredoc() + extend_props(self._parse_properties()) if not expression: if self._match(TokenType.COMMAND): @@ -1817,11 +1832,17 @@ class Parser(metaclass=_Parser): self._retreat(index) return self._parse_sequence_properties() - return self.expression( - exp.Property, - this=key.to_dot() if isinstance(key, exp.Column) else key, - value=self._parse_bitwise() or self._parse_var(any_token=True), - ) + # Transform the key to exp.Dot if it's dotted identifiers wrapped in exp.Column or to exp.Var otherwise + if isinstance(key, exp.Column): + key = key.to_dot() if len(key.parts) > 1 else exp.var(key.name) + + value = self._parse_bitwise() or self._parse_var(any_token=True) + + # Transform the value to exp.Var if it was parsed as exp.Column(exp.Identifier()) + if isinstance(value, exp.Column): + value = exp.var(value.name) + + return self.expression(exp.Property, this=key, value=value) def _parse_stored(self) -> exp.FileFormatProperty: self._match(TokenType.ALIAS) @@ -1840,7 +1861,7 @@ class Parser(metaclass=_Parser): ), ) - def _parse_unquoted_field(self): + def _parse_unquoted_field(self) -> t.Optional[exp.Expression]: field = self._parse_field() if isinstance(field, exp.Identifier) and not field.quoted: field = exp.var(field) @@ -2780,7 +2801,13 @@ class Parser(metaclass=_Parser): if not alias and not columns: return None - return self.expression(exp.TableAlias, this=alias, columns=columns) + table_alias = self.expression(exp.TableAlias, this=alias, columns=columns) + 
+ # We bubble up comments from the Identifier to the TableAlias + if isinstance(alias, exp.Identifier): + table_alias.add_comments(alias.pop_comments()) + + return table_alias def _parse_subquery( self, this: t.Optional[exp.Expression], parse_alias: bool = True @@ -4047,7 +4074,7 @@ class Parser(metaclass=_Parser): return this return self.expression(exp.Escape, this=this, expression=self._parse_string()) - def _parse_interval(self, match_interval: bool = True) -> t.Optional[exp.Interval]: + def _parse_interval(self, match_interval: bool = True) -> t.Optional[exp.Add | exp.Interval]: index = self._index if not self._match(TokenType.INTERVAL) and match_interval: @@ -4077,23 +4104,33 @@ class Parser(metaclass=_Parser): if this and this.is_number: this = exp.Literal.string(this.name) elif this and this.is_string: - parts = this.name.split() - - if len(parts) == 2: + parts = exp.INTERVAL_STRING_RE.findall(this.name) + if len(parts) == 1: if unit: - # This is not actually a unit, it's something else (e.g. a "window side") - unit = None + # Unconsume the eagerly-parsed unit, since the real unit was part of the string self._retreat(self._index - 1) - this = exp.Literal.string(parts[0]) - unit = self.expression(exp.Var, this=parts[1].upper()) + this = exp.Literal.string(parts[0][0]) + unit = self.expression(exp.Var, this=parts[0][1].upper()) if self.INTERVAL_SPANS and self._match_text_seq("TO"): unit = self.expression( exp.IntervalSpan, this=unit, expression=self._parse_var(any_token=True, upper=True) ) - return self.expression(exp.Interval, this=this, unit=unit) + interval = self.expression(exp.Interval, this=this, unit=unit) + + index = self._index + self._match(TokenType.PLUS) + + # Convert INTERVAL 'val_1' unit_1 [+] ... 
[+] 'val_n' unit_n into a sum of intervals + if self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False): + return self.expression( + exp.Add, this=interval, expression=self._parse_interval(match_interval=False) + ) + + self._retreat(index) + return interval def _parse_bitwise(self) -> t.Optional[exp.Expression]: this = self._parse_term() @@ -4155,39 +4192,50 @@ class Parser(metaclass=_Parser): return self.UNARY_PARSERS[self._prev.token_type](self) return self._parse_at_time_zone(self._parse_type()) - def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]: + def _parse_type( + self, parse_interval: bool = True, fallback_to_identifier: bool = False + ) -> t.Optional[exp.Expression]: interval = parse_interval and self._parse_interval() if interval: - # Convert INTERVAL 'val_1' unit_1 [+] ... [+] 'val_n' unit_n into a sum of intervals - while True: - index = self._index - self._match(TokenType.PLUS) - - if not self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False): - self._retreat(index) - break - - interval = self.expression( # type: ignore - exp.Add, this=interval, expression=self._parse_interval(match_interval=False) - ) - return interval index = self._index data_type = self._parse_types(check_func=True, allow_identifiers=False) - this = self._parse_column() if data_type: + index2 = self._index + this = self._parse_primary() + if isinstance(this, exp.Literal): parser = self.TYPE_LITERAL_PARSERS.get(data_type.this) if parser: return parser(self, this, data_type) + return self.expression(exp.Cast, this=this, to=data_type) - if not data_type.expressions: - self._retreat(index) - return self._parse_column() - return self._parse_column_ops(data_type) + # The expressions arg gets set by the parser when we have something like DECIMAL(38, 0) + # in the input SQL. 
In that case, we'll produce these tokens: DECIMAL ( 38 , 0 ) + # + # If the index difference here is greater than 1, that means the parser itself must have + # consumed additional tokens such as the DECIMAL scale and precision in the above example. + # + # If it's not greater than 1, then it must be 1, because we've consumed at least the type + # keyword, meaning that the expressions arg of the DataType must have gotten set by a + # callable in the TYPE_CONVERTERS mapping. For example, Snowflake converts DECIMAL to + # DECIMAL(38, 0)) in order to facilitate the data type's transpilation. + # + # In these cases, we don't really want to return the converted type, but instead retreat + # and try to parse a Column or Identifier in the section below. + if data_type.expressions and index2 - index > 1: + self._retreat(index2) + return self._parse_column_ops(data_type) + + self._retreat(index) + + if fallback_to_identifier: + return self._parse_id_var() + + this = self._parse_column() return this and self._parse_column_ops(this) def _parse_type_size(self) -> t.Optional[exp.DataTypeParam]: @@ -4251,7 +4299,7 @@ class Parser(metaclass=_Parser): if self._match(TokenType.L_PAREN): if is_struct: - expressions = self._parse_csv(self._parse_struct_types) + expressions = self._parse_csv(lambda: self._parse_struct_types(type_required=True)) elif nested: expressions = self._parse_csv( lambda: self._parse_types( @@ -4352,8 +4400,26 @@ class Parser(metaclass=_Parser): elif expressions: this.set("expressions", expressions) - while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): - this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True) + index = self._index + + # Postgres supports the INT ARRAY[3] syntax as a synonym for INT[3] + matched_array = self._match(TokenType.ARRAY) + + while self._curr: + matched_l_bracket = self._match(TokenType.L_BRACKET) + if not matched_l_bracket and not matched_array: + break + + matched_array = False + values = 
self._parse_csv(self._parse_conjunction) or None + if values and not schema: + self._retreat(index) + break + + this = exp.DataType( + this=exp.DataType.Type.ARRAY, expressions=[this], values=values, nested=True + ) + self._match(TokenType.R_BRACKET) if self.TYPE_CONVERTER and isinstance(this.this, exp.DataType.Type): converter = self.TYPE_CONVERTER.get(this.this) @@ -4364,17 +4430,21 @@ class Parser(metaclass=_Parser): def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]: index = self._index - this = self._parse_type(parse_interval=False) or self._parse_id_var() + this = ( + self._parse_type(parse_interval=False, fallback_to_identifier=True) + or self._parse_id_var() + ) self._match(TokenType.COLON) - column_def = self._parse_column_def(this) - if type_required and ( - (isinstance(this, exp.Column) and this.this is column_def) or this is column_def + if ( + type_required + and not isinstance(this, exp.DataType) + and not self._match_set(self.TYPE_TOKENS, advance=False) ): self._retreat(index) return self._parse_types() - return column_def + return self._parse_column_def(this) def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match_text_seq("AT", "TIME", "ZONE"): @@ -4401,6 +4471,47 @@ class Parser(metaclass=_Parser): return this + def _parse_colon_as_json_extract( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + casts = [] + json_path = [] + + while self._match(TokenType.COLON): + start_index = self._index + path = self._parse_column_ops(self._parse_field(any_token=True)) + + # The cast :: operator has a lower precedence than the extraction operator :, so + # we rearrange the AST appropriately to avoid casting the JSON path + while isinstance(path, exp.Cast): + casts.append(path.to) + path = path.this + + if casts: + dcolon_offset = next( + i + for i, t in enumerate(self._tokens[start_index:]) + if t.token_type == TokenType.DCOLON + ) + 
end_token = self._tokens[start_index + dcolon_offset - 1] + else: + end_token = self._prev + + if path: + json_path.append(self._find_sql(self._tokens[start_index], end_token)) + + if json_path: + this = self.expression( + exp.JSONExtract, + this=this, + expression=self.dialect.to_json_path(exp.Literal.string(".".join(json_path))), + ) + + while casts: + this = self.expression(exp.Cast, this=this, to=casts.pop()) + + return this + def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: this = self._parse_bracket(this) @@ -4444,8 +4555,10 @@ class Parser(metaclass=_Parser): ) else: this = self.expression(exp.Dot, this=this, expression=field) + this = self._parse_bracket(this) - return this + + return self._parse_colon_as_json_extract(this) if self.COLON_IS_JSON_EXTRACT else this def _parse_primary(self) -> t.Optional[exp.Expression]: if self._match_set(self.PRIMARY_PARSERS): @@ -4680,18 +4793,21 @@ class Parser(metaclass=_Parser): return self.expression(exp.SessionParameter, this=this, kind=kind) + def _parse_lambda_arg(self) -> t.Optional[exp.Expression]: + return self._parse_id_var() + def _parse_lambda(self, alias: bool = False) -> t.Optional[exp.Expression]: index = self._index if self._match(TokenType.L_PAREN): expressions = t.cast( - t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_id_var) + t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_lambda_arg) ) if not self._match(TokenType.R_PAREN): self._retreat(index) else: - expressions = [self._parse_id_var()] + expressions = [self._parse_lambda_arg()] if self._match_set(self.LAMBDAS): return self.LAMBDAS[self._prev.token_type](self, expressions) @@ -5964,7 +6080,19 @@ class Parser(metaclass=_Parser): return self.expression(exp.AlterColumn, this=column, default=self._parse_conjunction()) if self._match(TokenType.COMMENT): return self.expression(exp.AlterColumn, this=column, comment=self._parse_string()) - + if self._match_text_seq("DROP", "NOT", 
"NULL"): + return self.expression( + exp.AlterColumn, + this=column, + drop=True, + allow_null=True, + ) + if self._match_text_seq("SET", "NOT", "NULL"): + return self.expression( + exp.AlterColumn, + this=column, + allow_null=False, + ) self._match_text_seq("SET", "DATA") self._match_text_seq("TYPE") return self.expression( @@ -6182,8 +6310,10 @@ class Parser(metaclass=_Parser): return None right = self._parse_statement() or self._parse_id_var() - this = self.expression(exp.EQ, this=left, expression=right) + if isinstance(right, (exp.Column, exp.Identifier)): + right = exp.var(right.name) + this = self.expression(exp.EQ, this=left, expression=right) return self.expression(exp.SetItem, this=this, kind=kind) def _parse_set_transaction(self, global_: bool = False) -> exp.Expression: @@ -6433,14 +6563,25 @@ class Parser(metaclass=_Parser): return True def _replace_lambda( - self, node: t.Optional[exp.Expression], lambda_variables: t.Set[str] + self, node: t.Optional[exp.Expression], expressions: t.List[exp.Expression] ) -> t.Optional[exp.Expression]: if not node: return node + lambda_types = {e.name: e.args.get("to") or False for e in expressions} + for column in node.find_all(exp.Column): - if column.parts[0].name in lambda_variables: + typ = lambda_types.get(column.parts[0].name) + if typ is not None: dot_or_id = column.to_dot() if column.table else column.this + + if typ: + dot_or_id = self.expression( + exp.Cast, + this=dot_or_id, + to=typ, + ) + parent = column.parent while isinstance(parent, exp.Dot): @@ -6516,12 +6657,23 @@ class Parser(metaclass=_Parser): return self.expression(exp.WithOperator, this=this, op=op) def _parse_wrapped_options(self) -> t.List[t.Optional[exp.Expression]]: - opts = [] self._match(TokenType.EQ) self._match(TokenType.L_PAREN) + + opts: t.List[t.Optional[exp.Expression]] = [] while self._curr and not self._match(TokenType.R_PAREN): - opts.append(self._parse_conjunction()) + if self._match_text_seq("FORMAT_NAME", "="): + # The 
FORMAT_NAME can be set to an identifier for Snowflake and T-SQL, + # so we parse it separately to use _parse_field() + prop = self.expression( + exp.Property, this=exp.var("FORMAT_NAME"), value=self._parse_field() + ) + opts.append(prop) + else: + opts.append(self._parse_property()) + self._match(TokenType.COMMA) + return opts def _parse_copy_parameters(self) -> t.List[exp.CopyParameter]: @@ -6529,37 +6681,38 @@ class Parser(metaclass=_Parser): options = [] while self._curr and not self._match(TokenType.R_PAREN, advance=False): - option = self._parse_unquoted_field() - value = None + option = self._parse_var(any_token=True) + prev = self._prev.text.upper() - # Some options are defined as functions with the values as params - if not isinstance(option, exp.Func): - prev = self._prev.text.upper() - # Different dialects might separate options and values by white space, "=" and "AS" - self._match(TokenType.EQ) - self._match(TokenType.ALIAS) + # Different dialects might separate options and values by white space, "=" and "AS" + self._match(TokenType.EQ) + self._match(TokenType.ALIAS) - if prev == "FILE_FORMAT" and self._match(TokenType.L_PAREN): - # Snowflake FILE_FORMAT case - value = self._parse_wrapped_options() - else: - value = self._parse_unquoted_field() + param = self.expression(exp.CopyParameter, this=option) - param = self.expression(exp.CopyParameter, this=option, expression=value) - options.append(param) + if prev in self.COPY_INTO_VARLEN_OPTIONS and self._match( + TokenType.L_PAREN, advance=False + ): + # Snowflake FILE_FORMAT case, Databricks COPY & FORMAT options + param.set("expressions", self._parse_wrapped_options()) + elif prev == "FILE_FORMAT": + # T-SQL's external file format case + param.set("expression", self._parse_field()) + else: + param.set("expression", self._parse_unquoted_field()) - if sep: - self._match(sep) + options.append(param) + self._match(sep) return options def _parse_credentials(self) -> t.Optional[exp.Credentials]: expr = 
self.expression(exp.Credentials) - if self._match_text_seq("STORAGE_INTEGRATION", advance=False): - expr.set("storage", self._parse_conjunction()) + if self._match_text_seq("STORAGE_INTEGRATION", "="): + expr.set("storage", self._parse_field()) if self._match_text_seq("CREDENTIALS"): - # Snowflake supports CREDENTIALS = (...), while Redshift CREDENTIALS <string> + # Snowflake case: CREDENTIALS = (...), Redshift case: CREDENTIALS <string> creds = ( self._parse_wrapped_options() if self._match(TokenType.EQ) else self._parse_field() ) @@ -6582,7 +6735,7 @@ class Parser(metaclass=_Parser): self._match(TokenType.INTO) this = ( - self._parse_conjunction() + self._parse_select(nested=True, parse_subquery_alias=False) if self._match(TokenType.L_PAREN, advance=False) else self._parse_table(schema=True) ) |