summaryrefslogtreecommitdiffstats
path: root/sqlglot/parser.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r--sqlglot/parser.py272
1 files changed, 170 insertions, 102 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index c97b19a..42777d1 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -40,22 +40,23 @@ class _Parser(type):
class Parser(metaclass=_Parser):
"""
- Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer`
- and produces a parsed syntax tree.
-
- Args
- error_level (ErrorLevel): the desired error level. Default: ErrorLevel.RAISE.
- error_message_context (int): determines the amount of context to capture from
- a query string when displaying the error message (in number of characters).
+ Parser consumes a list of tokens produced by the `sqlglot.tokens.Tokenizer` and produces
+ a parsed syntax tree.
+
+ Args:
+ error_level: the desired error level.
+ Default: ErrorLevel.RAISE
+ error_message_context: determines the amount of context to capture from a
+ query string when displaying the error message (in number of characters).
Default: 50.
- index_offset (int): Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list
+ index_offset: Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list.
Default: 0
- alias_post_tablesample (bool): If the table alias comes after tablesample
+ alias_post_tablesample: If the table alias comes after tablesample.
Default: False
- max_errors (int): Maximum number of error messages to include in a raised ParseError.
+ max_errors: Maximum number of error messages to include in a raised ParseError.
This is only relevant if error_level is ErrorLevel.RAISE.
Default: 3
- null_ordering (str): Indicates the default null ordering method to use if not explicitly set.
+ null_ordering: Indicates the default null ordering method to use if not explicitly set.
Options are "nulls_are_small", "nulls_are_large", "nulls_are_last".
Default: "nulls_are_small"
"""
@@ -109,6 +110,8 @@ class Parser(metaclass=_Parser):
TokenType.TEXT,
TokenType.MEDIUMTEXT,
TokenType.LONGTEXT,
+ TokenType.MEDIUMBLOB,
+ TokenType.LONGBLOB,
TokenType.BINARY,
TokenType.VARBINARY,
TokenType.JSON,
@@ -176,6 +179,7 @@ class Parser(metaclass=_Parser):
TokenType.DIV,
TokenType.DISTKEY,
TokenType.DISTSTYLE,
+ TokenType.END,
TokenType.EXECUTE,
TokenType.ENGINE,
TokenType.ESCAPE,
@@ -468,9 +472,6 @@ class Parser(metaclass=_Parser):
TokenType.NULL: lambda self, _: self.expression(exp.Null),
TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
- TokenType.PARAMETER: lambda self, _: self.expression(
- exp.Parameter, this=self._parse_var() or self._parse_primary()
- ),
TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text),
TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
@@ -479,6 +480,16 @@ class Parser(metaclass=_Parser):
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
}
+ PLACEHOLDER_PARSERS = {
+ TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder),
+ TokenType.PARAMETER: lambda self: self.expression(
+ exp.Parameter, this=self._parse_var() or self._parse_primary()
+ ),
+ TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
+ if self._match_set((TokenType.NUMBER, TokenType.VAR))
+ else None,
+ }
+
RANGE_PARSERS = {
TokenType.BETWEEN: lambda self, this: self._parse_between(this),
TokenType.IN: lambda self, this: self._parse_in(this),
@@ -601,8 +612,7 @@ class Parser(metaclass=_Parser):
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
- # allows tables to have special tokens as prefixes
- TABLE_PREFIX_TOKENS: t.Set[TokenType] = set()
+ ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
STRICT_CAST = True
@@ -677,7 +687,7 @@ class Parser(metaclass=_Parser):
def parse_into(
self,
- expression_types: str | exp.Expression | t.Collection[exp.Expression | str],
+ expression_types: exp.IntoType,
raw_tokens: t.List[Token],
sql: t.Optional[str] = None,
) -> t.List[t.Optional[exp.Expression]]:
@@ -820,24 +830,8 @@ class Parser(metaclass=_Parser):
if self.error_level == ErrorLevel.IGNORE:
return
- for k in expression.args:
- if k not in expression.arg_types:
- self.raise_error(f"Unexpected keyword: '{k}' for {expression.__class__}")
- for k, mandatory in expression.arg_types.items():
- v = expression.args.get(k)
- if mandatory and (v is None or (isinstance(v, list) and not v)):
- self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}")
-
- if (
- args
- and isinstance(expression, exp.Func)
- and len(args) > len(expression.arg_types)
- and not expression.is_var_len_args
- ):
- self.raise_error(
- f"The number of provided arguments ({len(args)}) is greater than "
- f"the maximum number of supported arguments ({len(expression.arg_types)})"
- )
+ for error_message in expression.error_messages(args):
+ self.raise_error(error_message)
def _find_token(self, token: Token, sql: str) -> int:
line = 1
@@ -868,6 +862,9 @@ class Parser(metaclass=_Parser):
def _retreat(self, index: int) -> None:
self._advance(index - self._index)
+ def _parse_command(self) -> exp.Expression:
+ return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())
+
def _parse_statement(self) -> t.Optional[exp.Expression]:
if self._curr is None:
return None
@@ -876,11 +873,7 @@ class Parser(metaclass=_Parser):
return self.STATEMENT_PARSERS[self._prev.token_type](self)
if self._match_set(Tokenizer.COMMANDS):
- return self.expression(
- exp.Command,
- this=self._prev.text,
- expression=self._parse_string(),
- )
+ return self._parse_command()
expression = self._parse_expression()
expression = self._parse_set_operations(expression) if expression else self._parse_select()
@@ -942,12 +935,18 @@ class Parser(metaclass=_Parser):
no_primary_index = None
indexes = None
no_schema_binding = None
+ begin = None
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
- this = self._parse_user_defined_function()
+ this = self._parse_user_defined_function(kind=create_token.token_type)
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
- expression = self._parse_select_or_expression()
+ begin = self._match(TokenType.BEGIN)
+ return_ = self._match_text_seq("RETURN")
+ expression = self._parse_statement()
+
+ if return_:
+ expression = self.expression(exp.Return, this=expression)
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index()
elif create_token.token_type in (
@@ -1002,6 +1001,7 @@ class Parser(metaclass=_Parser):
no_primary_index=no_primary_index,
indexes=indexes,
no_schema_binding=no_schema_binding,
+ begin=begin,
)
def _parse_property(self) -> t.Optional[exp.Expression]:
@@ -1087,7 +1087,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
else:
- value = self._parse_schema(exp.Literal.string("TABLE"))
+ value = self._parse_schema(exp.Var(this="TABLE"))
else:
value = self._parse_types()
@@ -1550,7 +1550,7 @@ class Parser(metaclass=_Parser):
return None
index = self._parse_id_var()
columns = None
- if self._curr and self._curr.token_type == TokenType.L_PAREN:
+ if self._match(TokenType.L_PAREN, advance=False):
columns = self._parse_wrapped_csv(self._parse_column)
return self.expression(
exp.Index,
@@ -1561,6 +1561,27 @@ class Parser(metaclass=_Parser):
amp=amp,
)
+ def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
+ catalog = None
+ db = None
+ table = (not schema and self._parse_function()) or self._parse_id_var(any_token=False)
+
+ while self._match(TokenType.DOT):
+ if catalog:
+ # This allows nesting the table in arbitrarily many dot expressions if needed
+ table = self.expression(exp.Dot, this=table, expression=self._parse_id_var())
+ else:
+ catalog = db
+ db = table
+ table = self._parse_id_var()
+
+ if not table:
+ self.raise_error(f"Expected table name but got {self._curr}")
+
+ return self.expression(
+ exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
+ )
+
def _parse_table(
self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
) -> t.Optional[exp.Expression]:
@@ -1584,27 +1605,7 @@ class Parser(metaclass=_Parser):
if subquery:
return subquery
- catalog = None
- db = None
- table = (not schema and self._parse_function()) or self._parse_id_var(
- any_token=False, prefix_tokens=self.TABLE_PREFIX_TOKENS
- )
-
- while self._match(TokenType.DOT):
- if catalog:
- # This allows nesting the table in arbitrarily many dot expressions if needed
- table = self.expression(exp.Dot, this=table, expression=self._parse_id_var())
- else:
- catalog = db
- db = table
- table = self._parse_id_var()
-
- if not table:
- self.raise_error(f"Expected table name but got {self._curr}")
-
- this = self.expression(
- exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
- )
+ this = self._parse_table_parts(schema=schema)
if schema:
return self._parse_schema(this=this)
@@ -1889,7 +1890,7 @@ class Parser(metaclass=_Parser):
expression,
this=this,
distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
- expression=self._parse_select(nested=True),
+ expression=self._parse_set_operations(self._parse_select(nested=True)),
)
def _parse_expression(self) -> t.Optional[exp.Expression]:
@@ -2286,7 +2287,9 @@ class Parser(metaclass=_Parser):
self._match_r_paren(this)
return self._parse_window(this)
- def _parse_user_defined_function(self) -> t.Optional[exp.Expression]:
+ def _parse_user_defined_function(
+ self, kind: t.Optional[TokenType] = None
+ ) -> t.Optional[exp.Expression]:
this = self._parse_id_var()
while self._match(TokenType.DOT):
@@ -2297,7 +2300,9 @@ class Parser(metaclass=_Parser):
expressions = self._parse_csv(self._parse_udf_kwarg)
self._match_r_paren()
- return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
+ return self.expression(
+ exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True
+ )
def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]:
literal = self._parse_primary()
@@ -2371,10 +2376,6 @@ class Parser(metaclass=_Parser):
or self._parse_column_def(self._parse_field(any_token=True))
)
self._match_r_paren()
-
- if isinstance(this, exp.Literal):
- this = this.name
-
return self.expression(exp.Schema, this=this, expressions=args)
def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
@@ -2470,15 +2471,43 @@ class Parser(metaclass=_Parser):
def _parse_unique(self) -> exp.Expression:
return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())
+ def _parse_key_constraint_options(self) -> t.List[str]:
+ options = []
+ while True:
+ if not self._curr:
+ break
+
+ if self._match_text_seq("NOT", "ENFORCED"):
+ options.append("NOT ENFORCED")
+ elif self._match_text_seq("DEFERRABLE"):
+ options.append("DEFERRABLE")
+ elif self._match_text_seq("INITIALLY", "DEFERRED"):
+ options.append("INITIALLY DEFERRED")
+ elif self._match_text_seq("NORELY"):
+ options.append("NORELY")
+ elif self._match_text_seq("MATCH", "FULL"):
+ options.append("MATCH FULL")
+ elif self._match_text_seq("ON", "UPDATE", "NO ACTION"):
+ options.append("ON UPDATE NO ACTION")
+ elif self._match_text_seq("ON", "DELETE", "NO ACTION"):
+ options.append("ON DELETE NO ACTION")
+ else:
+ break
+
+ return options
+
def _parse_references(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.REFERENCES):
return None
- return self.expression(
- exp.Reference,
- this=self._parse_id_var(),
- expressions=self._parse_wrapped_id_vars(),
- )
+ expressions = None
+ this = self._parse_id_var()
+
+ if self._match(TokenType.L_PAREN, advance=False):
+ expressions = self._parse_wrapped_id_vars()
+
+ options = self._parse_key_constraint_options()
+ return self.expression(exp.Reference, this=this, expressions=expressions, options=options)
def _parse_foreign_key(self) -> exp.Expression:
expressions = self._parse_wrapped_id_vars()
@@ -2503,12 +2532,14 @@ class Parser(metaclass=_Parser):
options[kind] = action
return self.expression(
- exp.ForeignKey,
- expressions=expressions,
- reference=reference,
- **options, # type: ignore
+ exp.ForeignKey, expressions=expressions, reference=reference, **options # type: ignore
)
+ def _parse_primary_key(self) -> exp.Expression:
+ expressions = self._parse_wrapped_id_vars()
+ options = self._parse_key_constraint_options()
+ return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
+
def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match(TokenType.L_BRACKET):
return this
@@ -2631,7 +2662,7 @@ class Parser(metaclass=_Parser):
order = self._parse_order(this=expression)
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
- def _parse_convert(self, strict: bool) -> exp.Expression:
+ def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
to: t.Optional[exp.Expression]
this = self._parse_column()
@@ -2641,19 +2672,25 @@ class Parser(metaclass=_Parser):
to = self._parse_types()
else:
to = None
+
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
- def _parse_position(self) -> exp.Expression:
+ def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
args = self._parse_csv(self._parse_bitwise)
if self._match(TokenType.IN):
- args.append(self._parse_bitwise())
+ return self.expression(
+ exp.StrPosition, this=self._parse_bitwise(), substr=seq_get(args, 0)
+ )
- this = exp.StrPosition(
- this=seq_get(args, 1),
- substr=seq_get(args, 0),
- position=seq_get(args, 2),
- )
+ if haystack_first:
+ haystack = seq_get(args, 0)
+ needle = seq_get(args, 1)
+ else:
+ needle = seq_get(args, 0)
+ haystack = seq_get(args, 1)
+
+ this = exp.StrPosition(this=haystack, substr=needle, position=seq_get(args, 2))
self.validate_expression(this, args)
@@ -2894,24 +2931,26 @@ class Parser(metaclass=_Parser):
return None
def _parse_placeholder(self) -> t.Optional[exp.Expression]:
- if self._match(TokenType.PLACEHOLDER):
- return self.expression(exp.Placeholder)
- elif self._match(TokenType.COLON):
- if self._match_set((TokenType.NUMBER, TokenType.VAR)):
- return self.expression(exp.Placeholder, this=self._prev.text)
+ if self._match_set(self.PLACEHOLDER_PARSERS):
+ placeholder = self.PLACEHOLDER_PARSERS[self._prev.token_type](self)
+ if placeholder:
+ return placeholder
self._advance(-1)
return None
def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if not self._match(TokenType.EXCEPT):
return None
-
- return self._parse_wrapped_id_vars()
+ if self._match(TokenType.L_PAREN, advance=False):
+ return self._parse_wrapped_id_vars()
+ return self._parse_csv(self._parse_id_var)
def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if not self._match(TokenType.REPLACE):
return None
- return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression()))
+ if self._match(TokenType.L_PAREN, advance=False):
+ return self._parse_wrapped_csv(self._parse_expression)
+ return self._parse_csv(self._parse_expression)
def _parse_csv(
self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
@@ -3021,6 +3060,28 @@ class Parser(metaclass=_Parser):
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")
+ def _parse_add_constraint(self) -> t.Optional[exp.Expression]:
+ this = None
+ kind = self._prev.token_type
+
+ if kind == TokenType.CONSTRAINT:
+ this = self._parse_id_var()
+
+ if self._match(TokenType.CHECK):
+ expression = self._parse_wrapped(self._parse_conjunction)
+ enforced = self._match_text_seq("ENFORCED")
+
+ return self.expression(
+ exp.AddConstraint, this=this, expression=expression, enforced=enforced
+ )
+
+ if kind == TokenType.FOREIGN_KEY or self._match(TokenType.FOREIGN_KEY):
+ expression = self._parse_foreign_key()
+ elif kind == TokenType.PRIMARY_KEY or self._match(TokenType.PRIMARY_KEY):
+ expression = self._parse_primary_key()
+
+ return self.expression(exp.AddConstraint, this=this, expression=expression)
+
def _parse_alter(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.TABLE):
return None
@@ -3029,8 +3090,14 @@ class Parser(metaclass=_Parser):
this = self._parse_table(schema=True)
actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None
- if self._match_text_seq("ADD", advance=False):
- actions = self._parse_csv(self._parse_add_column)
+
+ index = self._index
+ if self._match_text_seq("ADD"):
+ if self._match_set(self.ADD_CONSTRAINT_TOKENS):
+ actions = self._parse_csv(self._parse_add_constraint)
+ else:
+ self._retreat(index)
+ actions = self._parse_csv(self._parse_add_column)
elif self._match_text_seq("DROP", advance=False):
actions = self._parse_csv(self._parse_drop_column)
elif self._match_text_seq("RENAME", "TO"):
@@ -3077,7 +3144,7 @@ class Parser(metaclass=_Parser):
def _parse_merge(self) -> exp.Expression:
self._match(TokenType.INTO)
- target = self._parse_table(schema=True)
+ target = self._parse_table()
self._match(TokenType.USING)
using = self._parse_table()
@@ -3146,12 +3213,13 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return None
- def _match(self, token_type):
+ def _match(self, token_type, advance=True):
if not self._curr:
return None
if self._curr.token_type == token_type:
- self._advance()
+ if advance:
+ self._advance()
return True
return None