summaryrefslogtreecommitdiffstats
path: root/sqlglot/parser.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-06-16 09:41:15 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-06-16 09:41:15 +0000
commit358a09296d7198a4cc142f1976de8f3eb3318e58 (patch)
tree762db96c44014dc4db5e9fc7f6709c138589155e /sqlglot/parser.py
parentAdding upstream version 15.2.0. (diff)
downloadsqlglot-358a09296d7198a4cc142f1976de8f3eb3318e58.tar.xz
sqlglot-358a09296d7198a4cc142f1976de8f3eb3318e58.zip
Adding upstream version 16.2.1.upstream/16.2.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r--sqlglot/parser.py682
1 files changed, 350 insertions, 332 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 96bd6e3..d6888c7 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -6,7 +6,8 @@ from collections import defaultdict
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
-from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get
+from sqlglot.helper import apply_index_offset, ensure_list, seq_get
+from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie
@@ -25,13 +26,14 @@ def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
for i in range(0, len(args), 2):
keys.append(args[i])
values.append(args[i + 1])
+
return exp.VarMap(
keys=exp.Array(expressions=keys),
values=exp.Array(expressions=values),
)
-def parse_like(args: t.List) -> exp.Expression:
+def parse_like(args: t.List) -> exp.Escape | exp.Like:
like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0))
return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like
@@ -47,33 +49,26 @@ def binary_range_parser(
class _Parser(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
- klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS)
- klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS)
+
+ klass.SHOW_TRIE = new_trie(key.split(" ") for key in klass.SHOW_PARSERS)
+ klass.SET_TRIE = new_trie(key.split(" ") for key in klass.SET_PARSERS)
return klass
class Parser(metaclass=_Parser):
"""
- Parser consumes a list of tokens produced by the `sqlglot.tokens.Tokenizer` and produces
- a parsed syntax tree.
+ Parser consumes a list of tokens produced by the Tokenizer and produces a parsed syntax tree.
Args:
- error_level: the desired error level.
+ error_level: The desired error level.
Default: ErrorLevel.IMMEDIATE
- error_message_context: determines the amount of context to capture from a
+ error_message_context: Determines the amount of context to capture from a
query string when displaying the error message (in number of characters).
- Default: 50.
- index_offset: Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list.
- Default: 0
- alias_post_tablesample: If the table alias comes after tablesample.
- Default: False
+ Default: 100
max_errors: Maximum number of error messages to include in a raised ParseError.
This is only relevant if error_level is ErrorLevel.RAISE.
Default: 3
- null_ordering: Indicates the default null ordering method to use if not explicitly set.
- Options are "nulls_are_small", "nulls_are_large", "nulls_are_last".
- Default: "nulls_are_small"
"""
FUNCTIONS: t.Dict[str, t.Callable] = {
@@ -83,7 +78,6 @@ class Parser(metaclass=_Parser):
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
"GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
- "IFNULL": exp.Coalesce.from_arg_list,
"LIKE": parse_like,
"TIME_TO_TIME_STR": lambda args: exp.Cast(
this=seq_get(args, 0),
@@ -108,8 +102,6 @@ class Parser(metaclass=_Parser):
TokenType.CURRENT_USER: exp.CurrentUser,
}
- JOIN_HINTS: t.Set[str] = set()
-
NESTED_TYPE_TOKENS = {
TokenType.ARRAY,
TokenType.MAP,
@@ -117,6 +109,10 @@ class Parser(metaclass=_Parser):
TokenType.STRUCT,
}
+ ENUM_TYPE_TOKENS = {
+ TokenType.ENUM,
+ }
+
TYPE_TOKENS = {
TokenType.BIT,
TokenType.BOOLEAN,
@@ -188,6 +184,7 @@ class Parser(metaclass=_Parser):
TokenType.VARIANT,
TokenType.OBJECT,
TokenType.INET,
+ TokenType.ENUM,
*NESTED_TYPE_TOKENS,
}
@@ -198,7 +195,10 @@ class Parser(metaclass=_Parser):
TokenType.SOME: exp.Any,
}
- RESERVED_KEYWORDS = {*Tokenizer.SINGLE_TOKENS.values(), TokenType.SELECT}
+ RESERVED_KEYWORDS = {
+ *Tokenizer.SINGLE_TOKENS.values(),
+ TokenType.SELECT,
+ }
DB_CREATABLES = {
TokenType.DATABASE,
@@ -216,6 +216,7 @@ class Parser(metaclass=_Parser):
*DB_CREATABLES,
}
+ # Tokens that can represent identifiers
ID_VAR_TOKENS = {
TokenType.VAR,
TokenType.ANTI,
@@ -224,6 +225,7 @@ class Parser(metaclass=_Parser):
TokenType.AUTO_INCREMENT,
TokenType.BEGIN,
TokenType.CACHE,
+ TokenType.CASE,
TokenType.COLLATE,
TokenType.COMMAND,
TokenType.COMMENT,
@@ -274,6 +276,7 @@ class Parser(metaclass=_Parser):
TokenType.TRUE,
TokenType.UNIQUE,
TokenType.UNPIVOT,
+ TokenType.UPDATE,
TokenType.VOLATILE,
TokenType.WINDOW,
*CREATABLES,
@@ -409,6 +412,8 @@ class Parser(metaclass=_Parser):
TokenType.ANTI,
}
+ JOIN_HINTS: t.Set[str] = set()
+
LAMBDAS = {
TokenType.ARROW: lambda self, expressions: self.expression(
exp.Lambda,
@@ -420,7 +425,7 @@ class Parser(metaclass=_Parser):
),
TokenType.FARROW: lambda self, expressions: self.expression(
exp.Kwarg,
- this=exp.Var(this=expressions[0].name),
+ this=exp.var(expressions[0].name),
expression=self._parse_conjunction(),
),
}
@@ -515,7 +520,7 @@ class Parser(metaclass=_Parser):
TokenType.USE: lambda self: self.expression(
exp.Use,
kind=self._match_texts(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA"))
- and exp.Var(this=self._prev.text),
+ and exp.var(self._prev.text),
this=self._parse_table(schema=False),
),
}
@@ -634,6 +639,7 @@ class Parser(metaclass=_Parser):
"TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
"TEMP": lambda self: self.expression(exp.TemporaryProperty),
"TEMPORARY": lambda self: self.expression(exp.TemporaryProperty),
+ "TO": lambda self: self._parse_to_table(),
"TRANSIENT": lambda self: self.expression(exp.TransientProperty),
"TTL": lambda self: self._parse_ttl(),
"USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
@@ -710,6 +716,7 @@ class Parser(metaclass=_Parser):
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
+ "CONCAT": lambda self: self._parse_concat(),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"DECODE": lambda self: self._parse_decode(),
"EXTRACT": lambda self: self._parse_extract(),
@@ -755,8 +762,11 @@ class Parser(metaclass=_Parser):
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
- TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
+ DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN}
+ PRE_VOLATILE_TOKENS = {TokenType.CREATE, TokenType.REPLACE, TokenType.UNIQUE}
+
+ TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
TRANSACTION_CHARACTERISTICS = {
"ISOLATION LEVEL REPEATABLE READ",
"ISOLATION LEVEL READ COMMITTED",
@@ -778,6 +788,8 @@ class Parser(metaclass=_Parser):
STRICT_CAST = True
+ CONCAT_NULL_OUTPUTS_STRING = False # A NULL arg in CONCAT yields NULL by default
+
CONVERT_TYPE_FIRST = False
PREFIXED_PIVOT_COLUMNS = False
@@ -789,40 +801,39 @@ class Parser(metaclass=_Parser):
__slots__ = (
"error_level",
"error_message_context",
+ "max_errors",
"sql",
"errors",
- "index_offset",
- "unnest_column_only",
- "alias_post_tablesample",
- "max_errors",
- "null_ordering",
"_tokens",
"_index",
"_curr",
"_next",
"_prev",
"_prev_comments",
- "_show_trie",
- "_set_trie",
)
+ # Autofilled
+ INDEX_OFFSET: int = 0
+ UNNEST_COLUMN_ONLY: bool = False
+ ALIAS_POST_TABLESAMPLE: bool = False
+ STRICT_STRING_CONCAT = False
+ NULL_ORDERING: str = "nulls_are_small"
+ SHOW_TRIE: t.Dict = {}
+ SET_TRIE: t.Dict = {}
+ FORMAT_MAPPING: t.Dict[str, str] = {}
+ FORMAT_TRIE: t.Dict = {}
+ TIME_MAPPING: t.Dict[str, str] = {}
+ TIME_TRIE: t.Dict = {}
+
def __init__(
self,
error_level: t.Optional[ErrorLevel] = None,
error_message_context: int = 100,
- index_offset: int = 0,
- unnest_column_only: bool = False,
- alias_post_tablesample: bool = False,
max_errors: int = 3,
- null_ordering: t.Optional[str] = None,
):
self.error_level = error_level or ErrorLevel.IMMEDIATE
self.error_message_context = error_message_context
- self.index_offset = index_offset
- self.unnest_column_only = unnest_column_only
- self.alias_post_tablesample = alias_post_tablesample
self.max_errors = max_errors
- self.null_ordering = null_ordering
self.reset()
def reset(self):
@@ -843,11 +854,11 @@ class Parser(metaclass=_Parser):
per parsed SQL statement.
Args:
- raw_tokens: the list of tokens.
- sql: the original SQL string, used to produce helpful debug messages.
+ raw_tokens: The list of tokens.
+ sql: The original SQL string, used to produce helpful debug messages.
Returns:
- The list of syntax trees.
+ The list of the produced syntax trees.
"""
return self._parse(
parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
@@ -865,23 +876,25 @@ class Parser(metaclass=_Parser):
of them, stopping at the first for which the parsing succeeds.
Args:
- expression_types: the expression type(s) to try and parse the token list into.
- raw_tokens: the list of tokens.
- sql: the original SQL string, used to produce helpful debug messages.
+ expression_types: The expression type(s) to try and parse the token list into.
+ raw_tokens: The list of tokens.
+ sql: The original SQL string, used to produce helpful debug messages.
Returns:
The target Expression.
"""
errors = []
- for expression_type in ensure_collection(expression_types):
+ for expression_type in ensure_list(expression_types):
parser = self.EXPRESSION_PARSERS.get(expression_type)
if not parser:
raise TypeError(f"No parser registered for {expression_type}")
+
try:
return self._parse(parser, raw_tokens, sql)
except ParseError as e:
e.errors[0]["into_expression"] = expression_type
errors.append(e)
+
raise ParseError(
f"Failed to parse '{sql or raw_tokens}' into {expression_types}",
errors=merge_errors(errors),
@@ -895,6 +908,7 @@ class Parser(metaclass=_Parser):
) -> t.List[t.Optional[exp.Expression]]:
self.reset()
self.sql = sql or ""
+
total = len(raw_tokens)
chunks: t.List[t.List[Token]] = [[]]
@@ -922,9 +936,7 @@ class Parser(metaclass=_Parser):
return expressions
def check_errors(self) -> None:
- """
- Logs or raises any found errors, depending on the chosen error level setting.
- """
+ """Logs or raises any found errors, depending on the chosen error level setting."""
if self.error_level == ErrorLevel.WARN:
for error in self.errors:
logger.error(str(error))
@@ -969,39 +981,38 @@ class Parser(metaclass=_Parser):
Creates a new, validated Expression.
Args:
- exp_class: the expression class to instantiate.
- comments: an optional list of comments to attach to the expression.
- kwargs: the arguments to set for the expression along with their respective values.
+ exp_class: The expression class to instantiate.
+ comments: An optional list of comments to attach to the expression.
+ kwargs: The arguments to set for the expression along with their respective values.
Returns:
The target expression.
"""
instance = exp_class(**kwargs)
instance.add_comments(comments) if comments else self._add_comments(instance)
- self.validate_expression(instance)
- return instance
+ return self.validate_expression(instance)
def _add_comments(self, expression: t.Optional[exp.Expression]) -> None:
if expression and self._prev_comments:
expression.add_comments(self._prev_comments)
self._prev_comments = None
- def validate_expression(
- self, expression: exp.Expression, args: t.Optional[t.List] = None
- ) -> None:
+ def validate_expression(self, expression: E, args: t.Optional[t.List] = None) -> E:
"""
- Validates an already instantiated expression, making sure that all its mandatory arguments
- are set.
+ Validates an Expression, making sure that all its mandatory arguments are set.
Args:
- expression: the expression to validate.
- args: an optional list of items that was used to instantiate the expression, if it's a Func.
+ expression: The expression to validate.
+ args: An optional list of items that was used to instantiate the expression, if it's a Func.
+
+ Returns:
+ The validated expression.
"""
- if self.error_level == ErrorLevel.IGNORE:
- return
+ if self.error_level != ErrorLevel.IGNORE:
+ for error_message in expression.error_messages(args):
+ self.raise_error(error_message)
- for error_message in expression.error_messages(args):
- self.raise_error(error_message)
+ return expression
def _find_sql(self, start: Token, end: Token) -> str:
return self.sql[start.start : end.end + 1]
@@ -1010,6 +1021,7 @@ class Parser(metaclass=_Parser):
self._index += times
self._curr = seq_get(self._tokens, self._index)
self._next = seq_get(self._tokens, self._index + 1)
+
if self._index > 0:
self._prev = self._tokens[self._index - 1]
self._prev_comments = self._prev.comments
@@ -1031,7 +1043,6 @@ class Parser(metaclass=_Parser):
self._match(TokenType.ON)
kind = self._match_set(self.CREATABLES) and self._prev
-
if not kind:
return self._parse_as_command(start)
@@ -1050,6 +1061,12 @@ class Parser(metaclass=_Parser):
exp.Comment, this=this, kind=kind.text, expression=self._parse_string(), exists=exists
)
+ def _parse_to_table(
+ self,
+ ) -> exp.ToTableProperty:
+ table = self._parse_table_parts(schema=True)
+ return self.expression(exp.ToTableProperty, this=table)
+
# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl
def _parse_ttl(self) -> exp.Expression:
def _parse_ttl_action() -> t.Optional[exp.Expression]:
@@ -1102,10 +1119,11 @@ class Parser(metaclass=_Parser):
expression = self._parse_set_operations(expression) if expression else self._parse_select()
return self._parse_query_modifiers(expression)
- def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]:
+ def _parse_drop(self) -> exp.Drop | exp.Command:
start = self._prev
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match_text_seq("MATERIALIZED")
+
kind = self._match_set(self.CREATABLES) and self._prev.text
if not kind:
return self._parse_as_command(start)
@@ -1129,21 +1147,23 @@ class Parser(metaclass=_Parser):
and self._match(TokenType.EXISTS)
)
- def _parse_create(self) -> t.Optional[exp.Expression]:
+ def _parse_create(self) -> exp.Create | exp.Command:
+ # Note: this can't be None because we've matched a statement parser
start = self._prev
- replace = self._prev.text.upper() == "REPLACE" or self._match_pair(
+ replace = start.text.upper() == "REPLACE" or self._match_pair(
TokenType.OR, TokenType.REPLACE
)
unique = self._match(TokenType.UNIQUE)
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
- self._match(TokenType.TABLE)
+ self._advance()
properties = None
create_token = self._match_set(self.CREATABLES) and self._prev
if not create_token:
- properties = self._parse_properties() # exp.Properties.Location.POST_CREATE
+ # exp.Properties.Location.POST_CREATE
+ properties = self._parse_properties()
create_token = self._match_set(self.CREATABLES) and self._prev
if not properties or not create_token:
@@ -1157,7 +1177,7 @@ class Parser(metaclass=_Parser):
begin = None
clone = None
- def extend_props(temp_props: t.Optional[exp.Expression]) -> None:
+ def extend_props(temp_props: t.Optional[exp.Properties]) -> None:
nonlocal properties
if properties and temp_props:
properties.expressions.extend(temp_props.expressions)
@@ -1166,6 +1186,8 @@ class Parser(metaclass=_Parser):
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function(kind=create_token.token_type)
+
+ # exp.Properties.Location.POST_SCHEMA ("schema" here is the UDF's type signature)
extend_props(self._parse_properties())
self._match(TokenType.ALIAS)
@@ -1190,13 +1212,8 @@ class Parser(metaclass=_Parser):
extend_props(self._parse_properties())
self._match(TokenType.ALIAS)
-
- # exp.Properties.Location.POST_ALIAS
- if not (
- self._match(TokenType.SELECT, advance=False)
- or self._match(TokenType.WITH, advance=False)
- or self._match(TokenType.L_PAREN, advance=False)
- ):
+ if not self._match_set(self.DDL_SELECT_TOKENS, advance=False):
+ # exp.Properties.Location.POST_ALIAS
extend_props(self._parse_properties())
expression = self._parse_ddl_select()
@@ -1206,7 +1223,7 @@ class Parser(metaclass=_Parser):
while True:
index = self._parse_index()
- # exp.Properties.Location.POST_EXPRESSION or exp.Properties.Location.POST_INDEX
+ # exp.Properties.Location.POST_EXPRESSION and POST_INDEX
extend_props(self._parse_properties())
if not index:
@@ -1296,7 +1313,7 @@ class Parser(metaclass=_Parser):
return None
- def _parse_stored(self) -> exp.Expression:
+ def _parse_stored(self) -> exp.FileFormatProperty:
self._match(TokenType.ALIAS)
input_format = self._parse_string() if self._match_text_seq("INPUTFORMAT") else None
@@ -1311,14 +1328,13 @@ class Parser(metaclass=_Parser):
else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
)
- def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression:
+ def _parse_property_assignment(self, exp_class: t.Type[E]) -> E:
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
return self.expression(exp_class, this=self._parse_field())
- def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Expression]:
+ def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Properties]:
properties = []
-
while True:
if before:
prop = self._parse_property_before()
@@ -1335,29 +1351,25 @@ class Parser(metaclass=_Parser):
return None
- def _parse_fallback(self, no: bool = False) -> exp.Expression:
+ def _parse_fallback(self, no: bool = False) -> exp.FallbackProperty:
return self.expression(
exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION")
)
- def _parse_volatile_property(self) -> exp.Expression:
+ def _parse_volatile_property(self) -> exp.VolatileProperty | exp.StabilityProperty:
if self._index >= 2:
pre_volatile_token = self._tokens[self._index - 2]
else:
pre_volatile_token = None
- if pre_volatile_token and pre_volatile_token.token_type in (
- TokenType.CREATE,
- TokenType.REPLACE,
- TokenType.UNIQUE,
- ):
+ if pre_volatile_token and pre_volatile_token.token_type in self.PRE_VOLATILE_TOKENS:
return exp.VolatileProperty()
return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE"))
def _parse_with_property(
self,
- ) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]:
+ ) -> t.Optional[exp.Expression] | t.List[t.Optional[exp.Expression]]:
self._match(TokenType.WITH)
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_property)
@@ -1376,7 +1388,7 @@ class Parser(metaclass=_Parser):
return self._parse_withisolatedloading()
# https://dev.mysql.com/doc/refman/8.0/en/create-view.html
- def _parse_definer(self) -> t.Optional[exp.Expression]:
+ def _parse_definer(self) -> t.Optional[exp.DefinerProperty]:
self._match(TokenType.EQ)
user = self._parse_id_var()
@@ -1388,18 +1400,18 @@ class Parser(metaclass=_Parser):
return exp.DefinerProperty(this=f"{user}@{host}")
- def _parse_withjournaltable(self) -> exp.Expression:
+ def _parse_withjournaltable(self) -> exp.WithJournalTableProperty:
self._match(TokenType.TABLE)
self._match(TokenType.EQ)
return self.expression(exp.WithJournalTableProperty, this=self._parse_table_parts())
- def _parse_log(self, no: bool = False) -> exp.Expression:
+ def _parse_log(self, no: bool = False) -> exp.LogProperty:
return self.expression(exp.LogProperty, no=no)
- def _parse_journal(self, **kwargs) -> exp.Expression:
+ def _parse_journal(self, **kwargs) -> exp.JournalProperty:
return self.expression(exp.JournalProperty, **kwargs)
- def _parse_checksum(self) -> exp.Expression:
+ def _parse_checksum(self) -> exp.ChecksumProperty:
self._match(TokenType.EQ)
on = None
@@ -1407,53 +1419,47 @@ class Parser(metaclass=_Parser):
on = True
elif self._match_text_seq("OFF"):
on = False
- default = self._match(TokenType.DEFAULT)
- return self.expression(
- exp.ChecksumProperty,
- on=on,
- default=default,
- )
+ return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT))
- def _parse_cluster(self) -> t.Optional[exp.Expression]:
+ def _parse_cluster(self) -> t.Optional[exp.Cluster]:
if not self._match_text_seq("BY"):
self._retreat(self._index - 1)
return None
- return self.expression(
- exp.Cluster,
- expressions=self._parse_csv(self._parse_ordered),
- )
- def _parse_freespace(self) -> exp.Expression:
+ return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))
+
+ def _parse_freespace(self) -> exp.FreespaceProperty:
self._match(TokenType.EQ)
return self.expression(
exp.FreespaceProperty, this=self._parse_number(), percent=self._match(TokenType.PERCENT)
)
- def _parse_mergeblockratio(self, no: bool = False, default: bool = False) -> exp.Expression:
+ def _parse_mergeblockratio(
+ self, no: bool = False, default: bool = False
+ ) -> exp.MergeBlockRatioProperty:
if self._match(TokenType.EQ):
return self.expression(
exp.MergeBlockRatioProperty,
this=self._parse_number(),
percent=self._match(TokenType.PERCENT),
)
- return self.expression(
- exp.MergeBlockRatioProperty,
- no=no,
- default=default,
- )
+
+ return self.expression(exp.MergeBlockRatioProperty, no=no, default=default)
def _parse_datablocksize(
self,
default: t.Optional[bool] = None,
minimum: t.Optional[bool] = None,
maximum: t.Optional[bool] = None,
- ) -> exp.Expression:
+ ) -> exp.DataBlocksizeProperty:
self._match(TokenType.EQ)
size = self._parse_number()
+
units = None
if self._match_texts(("BYTES", "KBYTES", "KILOBYTES")):
units = self._prev.text
+
return self.expression(
exp.DataBlocksizeProperty,
size=size,
@@ -1463,12 +1469,13 @@ class Parser(metaclass=_Parser):
maximum=maximum,
)
- def _parse_blockcompression(self) -> exp.Expression:
+ def _parse_blockcompression(self) -> exp.BlockCompressionProperty:
self._match(TokenType.EQ)
always = self._match_text_seq("ALWAYS")
manual = self._match_text_seq("MANUAL")
never = self._match_text_seq("NEVER")
default = self._match_text_seq("DEFAULT")
+
autotemp = None
if self._match_text_seq("AUTOTEMP"):
autotemp = self._parse_schema()
@@ -1482,7 +1489,7 @@ class Parser(metaclass=_Parser):
autotemp=autotemp,
)
- def _parse_withisolatedloading(self) -> exp.Expression:
+ def _parse_withisolatedloading(self) -> exp.IsolatedLoadingProperty:
no = self._match_text_seq("NO")
concurrent = self._match_text_seq("CONCURRENT")
self._match_text_seq("ISOLATED", "LOADING")
@@ -1498,7 +1505,7 @@ class Parser(metaclass=_Parser):
for_none=for_none,
)
- def _parse_locking(self) -> exp.Expression:
+ def _parse_locking(self) -> exp.LockingProperty:
if self._match(TokenType.TABLE):
kind = "TABLE"
elif self._match(TokenType.VIEW):
@@ -1553,14 +1560,14 @@ class Parser(metaclass=_Parser):
return self._parse_csv(self._parse_conjunction)
return []
- def _parse_partitioned_by(self) -> exp.Expression:
+ def _parse_partitioned_by(self) -> exp.PartitionedByProperty:
self._match(TokenType.EQ)
return self.expression(
exp.PartitionedByProperty,
this=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
- def _parse_withdata(self, no: bool = False) -> exp.Expression:
+ def _parse_withdata(self, no: bool = False) -> exp.WithDataProperty:
if self._match_text_seq("AND", "STATISTICS"):
statistics = True
elif self._match_text_seq("AND", "NO", "STATISTICS"):
@@ -1570,52 +1577,50 @@ class Parser(metaclass=_Parser):
return self.expression(exp.WithDataProperty, no=no, statistics=statistics)
- def _parse_no_property(self) -> t.Optional[exp.Property]:
+ def _parse_no_property(self) -> t.Optional[exp.NoPrimaryIndexProperty]:
if self._match_text_seq("PRIMARY", "INDEX"):
return exp.NoPrimaryIndexProperty()
return None
- def _parse_on_property(self) -> t.Optional[exp.Property]:
+ def _parse_on_property(self) -> t.Optional[exp.Expression]:
if self._match_text_seq("COMMIT", "PRESERVE", "ROWS"):
return exp.OnCommitProperty()
elif self._match_text_seq("COMMIT", "DELETE", "ROWS"):
return exp.OnCommitProperty(delete=True)
return None
- def _parse_distkey(self) -> exp.Expression:
+ def _parse_distkey(self) -> exp.DistKeyProperty:
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
- def _parse_create_like(self) -> t.Optional[exp.Expression]:
+ def _parse_create_like(self) -> t.Optional[exp.LikeProperty]:
table = self._parse_table(schema=True)
+
options = []
while self._match_texts(("INCLUDING", "EXCLUDING")):
this = self._prev.text.upper()
- id_var = self._parse_id_var()
+ id_var = self._parse_id_var()
if not id_var:
return None
options.append(
- self.expression(
- exp.Property,
- this=this,
- value=exp.Var(this=id_var.this.upper()),
- )
+ self.expression(exp.Property, this=this, value=exp.var(id_var.this.upper()))
)
+
return self.expression(exp.LikeProperty, this=table, expressions=options)
- def _parse_sortkey(self, compound: bool = False) -> exp.Expression:
+ def _parse_sortkey(self, compound: bool = False) -> exp.SortKeyProperty:
return self.expression(
- exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound
+ exp.SortKeyProperty, this=self._parse_wrapped_id_vars(), compound=compound
)
- def _parse_character_set(self, default: bool = False) -> exp.Expression:
+ def _parse_character_set(self, default: bool = False) -> exp.CharacterSetProperty:
self._match(TokenType.EQ)
return self.expression(
exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default
)
- def _parse_returns(self) -> exp.Expression:
+ def _parse_returns(self) -> exp.ReturnsProperty:
value: t.Optional[exp.Expression]
is_table = self._match(TokenType.TABLE)
@@ -1629,19 +1634,18 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
else:
- value = self._parse_schema(exp.Var(this="TABLE"))
+ value = self._parse_schema(exp.var("TABLE"))
else:
value = self._parse_types()
return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
- def _parse_describe(self) -> exp.Expression:
+ def _parse_describe(self) -> exp.Describe:
kind = self._match_set(self.CREATABLES) and self._prev.text
this = self._parse_table()
-
return self.expression(exp.Describe, this=this, kind=kind)
- def _parse_insert(self) -> exp.Expression:
+ def _parse_insert(self) -> exp.Insert:
overwrite = self._match(TokenType.OVERWRITE)
local = self._match_text_seq("LOCAL")
alternative = None
@@ -1673,11 +1677,11 @@ class Parser(metaclass=_Parser):
alternative=alternative,
)
- def _parse_on_conflict(self) -> t.Optional[exp.Expression]:
+ def _parse_on_conflict(self) -> t.Optional[exp.OnConflict]:
conflict = self._match_text_seq("ON", "CONFLICT")
duplicate = self._match_text_seq("ON", "DUPLICATE", "KEY")
- if not (conflict or duplicate):
+ if not conflict and not duplicate:
return None
nothing = None
@@ -1707,18 +1711,20 @@ class Parser(metaclass=_Parser):
constraint=constraint,
)
- def _parse_returning(self) -> t.Optional[exp.Expression]:
+ def _parse_returning(self) -> t.Optional[exp.Returning]:
if not self._match(TokenType.RETURNING):
return None
return self.expression(exp.Returning, expressions=self._parse_csv(self._parse_column))
- def _parse_row(self) -> t.Optional[exp.Expression]:
+ def _parse_row(self) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]:
if not self._match(TokenType.FORMAT):
return None
return self._parse_row_format()
- def _parse_row_format(self, match_row: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_row_format(
+ self, match_row: bool = False
+ ) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]:
if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT):
return None
@@ -1744,7 +1750,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.RowFormatDelimitedProperty, **kwargs) # type: ignore
- def _parse_load(self) -> exp.Expression:
+ def _parse_load(self) -> exp.LoadData | exp.Command:
if self._match_text_seq("DATA"):
local = self._match_text_seq("LOCAL")
self._match_text_seq("INPATH")
@@ -1764,7 +1770,7 @@ class Parser(metaclass=_Parser):
)
return self._parse_as_command(self._prev)
- def _parse_delete(self) -> exp.Expression:
+ def _parse_delete(self) -> exp.Delete:
self._match(TokenType.FROM)
return self.expression(
@@ -1775,7 +1781,7 @@ class Parser(metaclass=_Parser):
returning=self._parse_returning(),
)
- def _parse_update(self) -> exp.Expression:
+ def _parse_update(self) -> exp.Update:
return self.expression(
exp.Update,
**{ # type: ignore
@@ -1787,22 +1793,20 @@ class Parser(metaclass=_Parser):
},
)
- def _parse_uncache(self) -> exp.Expression:
+ def _parse_uncache(self) -> exp.Uncache:
if not self._match(TokenType.TABLE):
self.raise_error("Expecting TABLE after UNCACHE")
return self.expression(
- exp.Uncache,
- exists=self._parse_exists(),
- this=self._parse_table(schema=True),
+ exp.Uncache, exists=self._parse_exists(), this=self._parse_table(schema=True)
)
- def _parse_cache(self) -> exp.Expression:
+ def _parse_cache(self) -> exp.Cache:
lazy = self._match_text_seq("LAZY")
self._match(TokenType.TABLE)
table = self._parse_table(schema=True)
- options = []
+ options = []
if self._match_text_seq("OPTIONS"):
self._match_l_paren()
k = self._parse_string()
@@ -1820,7 +1824,7 @@ class Parser(metaclass=_Parser):
expression=self._parse_select(nested=True),
)
- def _parse_partition(self) -> t.Optional[exp.Expression]:
+ def _parse_partition(self) -> t.Optional[exp.Partition]:
if not self._match(TokenType.PARTITION):
return None
@@ -1828,7 +1832,7 @@ class Parser(metaclass=_Parser):
exp.Partition, expressions=self._parse_wrapped_csv(self._parse_conjunction)
)
- def _parse_value(self) -> exp.Expression:
+ def _parse_value(self) -> exp.Tuple:
if self._match(TokenType.L_PAREN):
expressions = self._parse_csv(self._parse_conjunction)
self._match_r_paren()
@@ -1926,7 +1930,7 @@ class Parser(metaclass=_Parser):
return self._parse_set_operations(this)
- def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.With]:
if not skip_with_token and not self._match(TokenType.WITH):
return None
@@ -1946,22 +1950,19 @@ class Parser(metaclass=_Parser):
exp.With, comments=comments, expressions=expressions, recursive=recursive
)
- def _parse_cte(self) -> exp.Expression:
+ def _parse_cte(self) -> exp.CTE:
alias = self._parse_table_alias()
if not alias or not alias.this:
self.raise_error("Expected CTE to have alias")
self._match(TokenType.ALIAS)
-
return self.expression(
- exp.CTE,
- this=self._parse_wrapped(self._parse_statement),
- alias=alias,
+ exp.CTE, this=self._parse_wrapped(self._parse_statement), alias=alias
)
def _parse_table_alias(
self, alias_tokens: t.Optional[t.Collection[TokenType]] = None
- ) -> t.Optional[exp.Expression]:
+ ) -> t.Optional[exp.TableAlias]:
any_token = self._match(TokenType.ALIAS)
alias = (
self._parse_id_var(any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
@@ -1982,9 +1983,10 @@ class Parser(metaclass=_Parser):
def _parse_subquery(
self, this: t.Optional[exp.Expression], parse_alias: bool = True
- ) -> t.Optional[exp.Expression]:
+ ) -> t.Optional[exp.Subquery]:
if not this:
return None
+
return self.expression(
exp.Subquery,
this=this,
@@ -2000,19 +2002,25 @@ class Parser(metaclass=_Parser):
expression = parser(self)
if expression:
+ if key == "limit":
+ offset = expression.args.pop("offset", None)
+ if offset:
+ this.set("offset", exp.Offset(expression=offset))
this.set(key, expression)
return this
- def _parse_hint(self) -> t.Optional[exp.Expression]:
+ def _parse_hint(self) -> t.Optional[exp.Hint]:
if self._match(TokenType.HINT):
hints = self._parse_csv(self._parse_function)
+
if not self._match_pair(TokenType.STAR, TokenType.SLASH):
self.raise_error("Expected */ after HINT")
+
return self.expression(exp.Hint, expressions=hints)
return None
- def _parse_into(self) -> t.Optional[exp.Expression]:
+ def _parse_into(self) -> t.Optional[exp.Into]:
if not self._match(TokenType.INTO):
return None
@@ -2039,7 +2047,7 @@ class Parser(metaclass=_Parser):
this=self._parse_query_modifiers(this) if modifiers else this,
)
- def _parse_match_recognize(self) -> t.Optional[exp.Expression]:
+ def _parse_match_recognize(self) -> t.Optional[exp.MatchRecognize]:
if not self._match(TokenType.MATCH_RECOGNIZE):
return None
@@ -2052,7 +2060,7 @@ class Parser(metaclass=_Parser):
)
if self._match_text_seq("ONE", "ROW", "PER", "MATCH"):
- rows = exp.Var(this="ONE ROW PER MATCH")
+ rows = exp.var("ONE ROW PER MATCH")
elif self._match_text_seq("ALL", "ROWS", "PER", "MATCH"):
text = "ALL ROWS PER MATCH"
if self._match_text_seq("SHOW", "EMPTY", "MATCHES"):
@@ -2061,7 +2069,7 @@ class Parser(metaclass=_Parser):
text += f" OMIT EMPTY MATCHES"
elif self._match_text_seq("WITH", "UNMATCHED", "ROWS"):
text += f" WITH UNMATCHED ROWS"
- rows = exp.Var(this=text)
+ rows = exp.var(text)
else:
rows = None
@@ -2075,7 +2083,7 @@ class Parser(metaclass=_Parser):
text += f" TO FIRST {self._advance_any().text}" # type: ignore
elif self._match_text_seq("TO", "LAST"):
text += f" TO LAST {self._advance_any().text}" # type: ignore
- after = exp.Var(this=text)
+ after = exp.var(text)
else:
after = None
@@ -2093,11 +2101,14 @@ class Parser(metaclass=_Parser):
paren += 1
if self._curr.token_type == TokenType.R_PAREN:
paren -= 1
+
end = self._prev
self._advance()
+
if paren > 0:
self.raise_error("Expecting )", self._curr)
- pattern = exp.Var(this=self._find_sql(start, end))
+
+ pattern = exp.var(self._find_sql(start, end))
else:
pattern = None
@@ -2127,7 +2138,7 @@ class Parser(metaclass=_Parser):
alias=self._parse_table_alias(),
)
- def _parse_lateral(self) -> t.Optional[exp.Expression]:
+ def _parse_lateral(self) -> t.Optional[exp.Lateral]:
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY)
@@ -2150,24 +2161,19 @@ class Parser(metaclass=_Parser):
expression=self._parse_function() or self._parse_id_var(any_token=False),
)
- table_alias: t.Optional[exp.Expression]
-
if view:
table = self._parse_id_var(any_token=False)
columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else []
- table_alias = self.expression(exp.TableAlias, this=table, columns=columns)
+ table_alias: t.Optional[exp.TableAlias] = self.expression(
+ exp.TableAlias, this=table, columns=columns
+ )
+ elif isinstance(this, exp.Subquery) and this.alias:
+ # Ensures parity between the Subquery's and the Lateral's "alias" args
+ table_alias = this.args["alias"].copy()
else:
table_alias = self._parse_table_alias()
- expression = self.expression(
- exp.Lateral,
- this=this,
- view=view,
- outer=outer,
- alias=table_alias,
- )
-
- return expression
+ return self.expression(exp.Lateral, this=this, view=view, outer=outer, alias=table_alias)
def _parse_join_parts(
self,
@@ -2178,7 +2184,7 @@ class Parser(metaclass=_Parser):
self._match_set(self.JOIN_KINDS) and self._prev,
)
- def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]:
if self._match(TokenType.COMMA):
return self.expression(exp.Join, this=self._parse_table())
@@ -2223,7 +2229,7 @@ class Parser(metaclass=_Parser):
def _parse_index(
self,
index: t.Optional[exp.Expression] = None,
- ) -> t.Optional[exp.Expression]:
+ ) -> t.Optional[exp.Index]:
if index:
unique = None
primary = None
@@ -2236,11 +2242,15 @@ class Parser(metaclass=_Parser):
unique = self._match(TokenType.UNIQUE)
primary = self._match_text_seq("PRIMARY")
amp = self._match_text_seq("AMP")
+
if not self._match(TokenType.INDEX):
return None
+
index = self._parse_id_var()
table = None
+ using = self._parse_field() if self._match(TokenType.USING) else None
+
if self._match(TokenType.L_PAREN, advance=False):
columns = self._parse_wrapped_csv(self._parse_ordered)
else:
@@ -2250,6 +2260,7 @@ class Parser(metaclass=_Parser):
exp.Index,
this=index,
table=table,
+ using=using,
columns=columns,
unique=unique,
primary=primary,
@@ -2259,7 +2270,7 @@ class Parser(metaclass=_Parser):
def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
return (
- (not schema and self._parse_function())
+ (not schema and self._parse_function(optional_parens=False))
or self._parse_id_var(any_token=False)
or self._parse_string_as_identifier()
or self._parse_placeholder()
@@ -2314,7 +2325,7 @@ class Parser(metaclass=_Parser):
if schema:
return self._parse_schema(this=this)
- if self.alias_post_tablesample:
+ if self.ALIAS_POST_TABLESAMPLE:
table_sample = self._parse_table_sample()
alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
@@ -2331,7 +2342,7 @@ class Parser(metaclass=_Parser):
)
self._match_r_paren()
- if not self.alias_post_tablesample:
+ if not self.ALIAS_POST_TABLESAMPLE:
table_sample = self._parse_table_sample()
if table_sample:
@@ -2340,46 +2351,47 @@ class Parser(metaclass=_Parser):
return this
- def _parse_unnest(self) -> t.Optional[exp.Expression]:
+ def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
if not self._match(TokenType.UNNEST):
return None
expressions = self._parse_wrapped_csv(self._parse_type)
ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
- alias = self._parse_table_alias()
- if alias and self.unnest_column_only:
+ alias = self._parse_table_alias() if with_alias else None
+
+ if alias and self.UNNEST_COLUMN_ONLY:
if alias.args.get("columns"):
self.raise_error("Unexpected extra column alias in unnest.")
+
alias.set("columns", [alias.this])
alias.set("this", None)
offset = None
if self._match_pair(TokenType.WITH, TokenType.OFFSET):
self._match(TokenType.ALIAS)
- offset = self._parse_id_var() or exp.Identifier(this="offset")
+ offset = self._parse_id_var() or exp.to_identifier("offset")
return self.expression(
- exp.Unnest,
- expressions=expressions,
- ordinality=ordinality,
- alias=alias,
- offset=offset,
+ exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias, offset=offset
)
- def _parse_derived_table_values(self) -> t.Optional[exp.Expression]:
+ def _parse_derived_table_values(self) -> t.Optional[exp.Values]:
is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
if not is_derived and not self._match(TokenType.VALUES):
return None
expressions = self._parse_csv(self._parse_value)
+ alias = self._parse_table_alias()
if is_derived:
self._match_r_paren()
- return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias())
+ return self.expression(
+ exp.Values, expressions=expressions, alias=alias or self._parse_table_alias()
+ )
- def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.TableSample]:
if not self._match(TokenType.TABLE_SAMPLE) and not (
as_modifier and self._match_text_seq("USING", "SAMPLE")
):
@@ -2456,7 +2468,7 @@ class Parser(metaclass=_Parser):
exp.Pivot, this=this, expressions=expressions, using=using, group=group
)
- def _parse_pivot(self) -> t.Optional[exp.Expression]:
+ def _parse_pivot(self) -> t.Optional[exp.Pivot]:
index = self._index
if self._match(TokenType.PIVOT):
@@ -2519,7 +2531,7 @@ class Parser(metaclass=_Parser):
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
return [agg.alias for agg in aggregations]
- def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Where]:
if not skip_where_token and not self._match(TokenType.WHERE):
return None
@@ -2527,7 +2539,7 @@ class Parser(metaclass=_Parser):
exp.Where, comments=self._prev_comments, this=self._parse_conjunction()
)
- def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Group]:
if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
return None
@@ -2578,12 +2590,12 @@ class Parser(metaclass=_Parser):
return self._parse_column()
- def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Having]:
if not skip_having_token and not self._match(TokenType.HAVING):
return None
return self.expression(exp.Having, this=self._parse_conjunction())
- def _parse_qualify(self) -> t.Optional[exp.Expression]:
+ def _parse_qualify(self) -> t.Optional[exp.Qualify]:
if not self._match(TokenType.QUALIFY):
return None
return self.expression(exp.Qualify, this=self._parse_conjunction())
@@ -2598,16 +2610,15 @@ class Parser(metaclass=_Parser):
exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
)
- def _parse_sort(
- self, exp_class: t.Type[exp.Expression], *texts: str
- ) -> t.Optional[exp.Expression]:
+ def _parse_sort(self, exp_class: t.Type[E], *texts: str) -> t.Optional[E]:
if not self._match_text_seq(*texts):
return None
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
- def _parse_ordered(self) -> exp.Expression:
+ def _parse_ordered(self) -> exp.Ordered:
this = self._parse_conjunction()
self._match(TokenType.ASC)
+
is_desc = self._match(TokenType.DESC)
is_nulls_first = self._match_text_seq("NULLS", "FIRST")
is_nulls_last = self._match_text_seq("NULLS", "LAST")
@@ -2615,13 +2626,14 @@ class Parser(metaclass=_Parser):
asc = not desc
nulls_first = is_nulls_first or False
explicitly_null_ordered = is_nulls_first or is_nulls_last
+
if (
not explicitly_null_ordered
and (
- (asc and self.null_ordering == "nulls_are_small")
- or (desc and self.null_ordering != "nulls_are_small")
+ (asc and self.NULL_ORDERING == "nulls_are_small")
+ or (desc and self.NULL_ORDERING != "nulls_are_small")
)
- and self.null_ordering != "nulls_are_last"
+ and self.NULL_ORDERING != "nulls_are_last"
):
nulls_first = True
@@ -2632,9 +2644,15 @@ class Parser(metaclass=_Parser):
) -> t.Optional[exp.Expression]:
if self._match(TokenType.TOP if top else TokenType.LIMIT):
limit_paren = self._match(TokenType.L_PAREN)
- limit_exp = self.expression(
- exp.Limit, this=this, expression=self._parse_number() if top else self._parse_term()
- )
+ expression = self._parse_number() if top else self._parse_term()
+
+ if self._match(TokenType.COMMA):
+ offset = expression
+ expression = self._parse_term()
+ else:
+ offset = None
+
+ limit_exp = self.expression(exp.Limit, this=this, expression=expression, offset=offset)
if limit_paren:
self._match_r_paren()
@@ -2667,17 +2685,15 @@ class Parser(metaclass=_Parser):
return this
def _parse_offset(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
- if not self._match_set((TokenType.OFFSET, TokenType.COMMA)):
+ if not self._match(TokenType.OFFSET):
return this
count = self._parse_number()
self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count)
- def _parse_locks(self) -> t.List[exp.Expression]:
- # Lists are invariant, so we need to use a type hint here
- locks: t.List[exp.Expression] = []
-
+ def _parse_locks(self) -> t.List[exp.Lock]:
+ locks = []
while True:
if self._match_text_seq("FOR", "UPDATE"):
update = True
@@ -2768,6 +2784,7 @@ class Parser(metaclass=_Parser):
def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
index = self._index - 1
negate = self._match(TokenType.NOT)
+
if self._match_text_seq("DISTINCT", "FROM"):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
return self.expression(klass, this=this, expression=self._parse_expression())
@@ -2781,7 +2798,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Not, this=this) if negate else this
def _parse_in(self, this: t.Optional[exp.Expression], alias: bool = False) -> exp.In:
- unnest = self._parse_unnest()
+ unnest = self._parse_unnest(with_alias=False)
if unnest:
this = self.expression(exp.In, this=this, unnest=unnest)
elif self._match(TokenType.L_PAREN):
@@ -2798,7 +2815,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_between(self, this: exp.Expression) -> exp.Expression:
+ def _parse_between(self, this: exp.Expression) -> exp.Between:
low = self._parse_bitwise()
self._match(TokenType.AND)
high = self._parse_bitwise()
@@ -2809,7 +2826,7 @@ class Parser(metaclass=_Parser):
return this
return self.expression(exp.Escape, this=this, expression=self._parse_string())
- def _parse_interval(self) -> t.Optional[exp.Expression]:
+ def _parse_interval(self) -> t.Optional[exp.Interval]:
if not self._match(TokenType.INTERVAL):
return None
@@ -2840,9 +2857,7 @@ class Parser(metaclass=_Parser):
while True:
if self._match_set(self.BITWISE):
this = self.expression(
- self.BITWISE[self._prev.token_type],
- this=this,
- expression=self._parse_term(),
+ self.BITWISE[self._prev.token_type], this=this, expression=self._parse_term()
)
elif self._match_pair(TokenType.LT, TokenType.LT):
this = self.expression(
@@ -2890,7 +2905,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_type_size(self) -> t.Optional[exp.Expression]:
+ def _parse_type_size(self) -> t.Optional[exp.DataTypeSize]:
this = self._parse_type()
if not this:
return None
@@ -2926,6 +2941,8 @@ class Parser(metaclass=_Parser):
expressions = self._parse_csv(
lambda: self._parse_types(check_func=check_func, schema=schema)
)
+ elif type_token in self.ENUM_TYPE_TOKENS:
+ expressions = self._parse_csv(self._parse_primary)
else:
expressions = self._parse_csv(self._parse_type_size)
@@ -2943,11 +2960,7 @@ class Parser(metaclass=_Parser):
)
while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
- this = exp.DataType(
- this=exp.DataType.Type.ARRAY,
- expressions=[this],
- nested=True,
- )
+ this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True)
return this
@@ -2973,23 +2986,14 @@ class Parser(metaclass=_Parser):
value: t.Optional[exp.Expression] = None
if type_token in self.TIMESTAMPS:
- if self._match_text_seq("WITH", "TIME", "ZONE") or type_token == TokenType.TIMESTAMPTZ:
+ if self._match_text_seq("WITH", "TIME", "ZONE"):
+ maybe_func = False
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
- elif (
- self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE")
- or type_token == TokenType.TIMESTAMPLTZ
- ):
+ elif self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE"):
+ maybe_func = False
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
elif self._match_text_seq("WITHOUT", "TIME", "ZONE"):
- if type_token == TokenType.TIME:
- value = exp.DataType(this=exp.DataType.Type.TIME, expressions=expressions)
- else:
- value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
-
- maybe_func = maybe_func and value is None
-
- if value is None:
- value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
+ maybe_func = False
elif type_token == TokenType.INTERVAL:
unit = self._parse_var()
@@ -3037,7 +3041,7 @@ class Parser(metaclass=_Parser):
return self._parse_bracket(this)
return self._parse_column_ops(this)
- def _parse_column_ops(self, this: exp.Expression) -> exp.Expression:
+ def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
this = self._parse_bracket(this)
while self._match_set(self.COLUMN_OPERATORS):
@@ -3057,7 +3061,7 @@ class Parser(metaclass=_Parser):
else exp.Literal.string(value)
)
else:
- field = self._parse_field(anonymous_func=True)
+ field = self._parse_field(anonymous_func=True, any_token=True)
if isinstance(field, exp.Func):
# bigquery allows function calls like x.y.count(...)
@@ -3089,8 +3093,10 @@ class Parser(metaclass=_Parser):
expressions = [primary]
while self._match(TokenType.STRING):
expressions.append(exp.Literal.string(self._prev.text))
+
if len(expressions) > 1:
return self.expression(exp.Concat, expressions=expressions)
+
return primary
if self._match_pair(TokenType.DOT, TokenType.NUMBER):
@@ -3118,8 +3124,8 @@ class Parser(metaclass=_Parser):
if this:
this.add_comments(comments)
- self._match_r_paren(expression=this)
+ self._match_r_paren(expression=this)
return this
return None
@@ -3137,18 +3143,21 @@ class Parser(metaclass=_Parser):
)
def _parse_function(
- self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False
+ self,
+ functions: t.Optional[t.Dict[str, t.Callable]] = None,
+ anonymous: bool = False,
+ optional_parens: bool = True,
) -> t.Optional[exp.Expression]:
if not self._curr:
return None
token_type = self._curr.token_type
- if self._match_set(self.NO_PAREN_FUNCTION_PARSERS):
+ if optional_parens and self._match_set(self.NO_PAREN_FUNCTION_PARSERS):
return self.NO_PAREN_FUNCTION_PARSERS[token_type](self)
if not self._next or self._next.token_type != TokenType.L_PAREN:
- if token_type in self.NO_PAREN_FUNCTIONS:
+ if optional_parens and token_type in self.NO_PAREN_FUNCTIONS:
self._advance()
return self.expression(self.NO_PAREN_FUNCTIONS[token_type])
@@ -3182,8 +3191,7 @@ class Parser(metaclass=_Parser):
args = self._parse_csv(lambda: self._parse_lambda(alias=alias))
if function and not anonymous:
- this = function(args)
- self.validate_expression(this, args)
+ this = self.validate_expression(function(args), args)
else:
this = self.expression(exp.Anonymous, this=this, expressions=args)
@@ -3210,14 +3218,14 @@ class Parser(metaclass=_Parser):
exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True
)
- def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]:
+ def _parse_introducer(self, token: Token) -> exp.Introducer | exp.Identifier:
literal = self._parse_primary()
if literal:
return self.expression(exp.Introducer, this=token.text, expression=literal)
return self.expression(exp.Identifier, this=token.text)
- def _parse_session_parameter(self) -> exp.Expression:
+ def _parse_session_parameter(self) -> exp.SessionParameter:
kind = None
this = self._parse_id_var() or self._parse_primary()
@@ -3255,7 +3263,7 @@ class Parser(metaclass=_Parser):
if isinstance(this, exp.EQ):
left = this.this
if isinstance(left, exp.Column):
- left.replace(exp.Var(this=left.text("this")))
+ left.replace(exp.var(left.text("this")))
return self._parse_limit(self._parse_order(self._parse_respect_or_ignore_nulls(this)))
@@ -3279,6 +3287,7 @@ class Parser(metaclass=_Parser):
lambda: self._parse_constraint()
or self._parse_column_def(self._parse_field(any_token=True))
)
+
self._match_r_paren()
return self.expression(exp.Schema, this=this, expressions=args)
@@ -3286,6 +3295,7 @@ class Parser(metaclass=_Parser):
# column defs are not really columns, they're identifiers
if isinstance(this, exp.Column):
this = this.this
+
kind = self._parse_types(schema=True)
if self._match_text_seq("FOR", "ORDINALITY"):
@@ -3303,7 +3313,9 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)
- def _parse_auto_increment(self) -> exp.Expression:
+ def _parse_auto_increment(
+ self,
+ ) -> exp.GeneratedAsIdentityColumnConstraint | exp.AutoIncrementColumnConstraint:
start = None
increment = None
@@ -3321,7 +3333,7 @@ class Parser(metaclass=_Parser):
return exp.AutoIncrementColumnConstraint()
- def _parse_compress(self) -> exp.Expression:
+ def _parse_compress(self) -> exp.CompressColumnConstraint:
if self._match(TokenType.L_PAREN, advance=False):
return self.expression(
exp.CompressColumnConstraint, this=self._parse_wrapped_csv(self._parse_bitwise)
@@ -3329,7 +3341,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.CompressColumnConstraint, this=self._parse_bitwise())
- def _parse_generated_as_identity(self) -> exp.Expression:
+ def _parse_generated_as_identity(self) -> exp.GeneratedAsIdentityColumnConstraint:
if self._match_text_seq("BY", "DEFAULT"):
on_null = self._match_pair(TokenType.ON, TokenType.NULL)
this = self.expression(
@@ -3364,11 +3376,13 @@ class Parser(metaclass=_Parser):
return this
- def _parse_inline(self) -> t.Optional[exp.Expression]:
+ def _parse_inline(self) -> exp.InlineLengthColumnConstraint:
self._match_text_seq("LENGTH")
return self.expression(exp.InlineLengthColumnConstraint, this=self._parse_bitwise())
- def _parse_not_constraint(self) -> t.Optional[exp.Expression]:
+ def _parse_not_constraint(
+ self,
+ ) -> t.Optional[exp.NotNullColumnConstraint | exp.CaseSpecificColumnConstraint]:
if self._match_text_seq("NULL"):
return self.expression(exp.NotNullColumnConstraint)
if self._match_text_seq("CASESPECIFIC"):
@@ -3417,7 +3431,7 @@ class Parser(metaclass=_Parser):
return self.CONSTRAINT_PARSERS[constraint](self)
- def _parse_unique(self) -> exp.Expression:
+ def _parse_unique(self) -> exp.UniqueColumnConstraint:
self._match_text_seq("KEY")
return self.expression(
exp.UniqueColumnConstraint, this=self._parse_schema(self._parse_id_var(any_token=False))
@@ -3460,7 +3474,7 @@ class Parser(metaclass=_Parser):
return options
- def _parse_references(self, match: bool = True) -> t.Optional[exp.Expression]:
+ def _parse_references(self, match: bool = True) -> t.Optional[exp.Reference]:
if match and not self._match(TokenType.REFERENCES):
return None
@@ -3473,7 +3487,7 @@ class Parser(metaclass=_Parser):
options = self._parse_key_constraint_options()
return self.expression(exp.Reference, this=this, expressions=expressions, options=options)
- def _parse_foreign_key(self) -> exp.Expression:
+ def _parse_foreign_key(self) -> exp.ForeignKey:
expressions = self._parse_wrapped_id_vars()
reference = self._parse_references()
options = {}
@@ -3501,7 +3515,7 @@ class Parser(metaclass=_Parser):
def _parse_primary_key(
self, wrapped_optional: bool = False, in_props: bool = False
- ) -> exp.Expression:
+ ) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey:
desc = (
self._match_set((TokenType.ASC, TokenType.DESC))
and self._prev.token_type == TokenType.DESC
@@ -3514,15 +3528,7 @@ class Parser(metaclass=_Parser):
options = self._parse_key_constraint_options()
return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
- @t.overload
- def _parse_bracket(self, this: exp.Expression) -> exp.Expression:
- ...
-
- @t.overload
def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
- ...
-
- def _parse_bracket(self, this):
if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)):
return this
@@ -3541,7 +3547,7 @@ class Parser(metaclass=_Parser):
elif not this or this.name.upper() == "ARRAY":
this = self.expression(exp.Array, expressions=expressions)
else:
- expressions = apply_index_offset(this, expressions, -self.index_offset)
+ expressions = apply_index_offset(this, expressions, -self.INDEX_OFFSET)
this = self.expression(exp.Bracket, this=this, expressions=expressions)
if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
@@ -3582,8 +3588,7 @@ class Parser(metaclass=_Parser):
def _parse_if(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.L_PAREN):
args = self._parse_csv(self._parse_conjunction)
- this = exp.If.from_arg_list(args)
- self.validate_expression(this, args)
+ this = self.validate_expression(exp.If.from_arg_list(args), args)
self._match_r_paren()
else:
index = self._index - 1
@@ -3601,7 +3606,7 @@ class Parser(metaclass=_Parser):
return self._parse_window(this)
- def _parse_extract(self) -> exp.Expression:
+ def _parse_extract(self) -> exp.Extract:
this = self._parse_function() or self._parse_var() or self._parse_type()
if self._match(TokenType.FROM):
@@ -3630,9 +3635,37 @@ class Parser(metaclass=_Parser):
elif to.this == exp.DataType.Type.CHAR:
if self._match(TokenType.CHARACTER_SET):
to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
+ elif to.this in exp.DataType.TEMPORAL_TYPES and self._match(TokenType.FORMAT):
+ fmt = self._parse_string()
+
+ return self.expression(
+ exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime,
+ this=this,
+ format=exp.Literal.string(
+ format_time(
+ fmt.this if fmt else "",
+ self.FORMAT_MAPPING or self.TIME_MAPPING,
+ self.FORMAT_TRIE or self.TIME_TRIE,
+ )
+ ),
+ )
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+ def _parse_concat(self) -> t.Optional[exp.Expression]:
+ args = self._parse_csv(self._parse_conjunction)
+ if self.CONCAT_NULL_OUTPUTS_STRING:
+ args = [exp.func("COALESCE", arg, exp.Literal.string("")) for arg in args]
+
+ # Some dialects (e.g. Trino) don't allow a single-argument CONCAT call, so when
+ # we find such a call we replace it with its argument.
+ if len(args) == 1:
+ return args[0]
+
+ return self.expression(
+ exp.Concat if self.STRICT_STRING_CONCAT else exp.SafeConcat, expressions=args
+ )
+
def _parse_string_agg(self) -> exp.Expression:
expression: t.Optional[exp.Expression]
@@ -3654,9 +3687,7 @@ class Parser(metaclass=_Parser):
# the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them.
if not self._match_text_seq("WITHIN", "GROUP"):
self._retreat(index)
- this = exp.GroupConcat.from_arg_list(args)
- self.validate_expression(this, args)
- return this
+ return self.validate_expression(exp.GroupConcat.from_arg_list(args), args)
self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller)
order = self._parse_order(this=expression)
@@ -3679,7 +3710,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
- def _parse_decode(self) -> t.Optional[exp.Expression]:
+ def _parse_decode(self) -> t.Optional[exp.Decode | exp.Case]:
"""
There are generally two variants of the DECODE function:
@@ -3726,18 +3757,20 @@ class Parser(metaclass=_Parser):
return exp.Case(ifs=ifs, default=expressions[-1] if len(expressions) % 2 == 1 else None)
- def _parse_json_key_value(self) -> t.Optional[exp.Expression]:
+ def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]:
self._match_text_seq("KEY")
key = self._parse_field()
self._match(TokenType.COLON)
self._match_text_seq("VALUE")
value = self._parse_field()
+
if not key and not value:
return None
return self.expression(exp.JSONKeyValue, this=key, expression=value)
- def _parse_json_object(self) -> exp.Expression:
- expressions = self._parse_csv(self._parse_json_key_value)
+ def _parse_json_object(self) -> exp.JSONObject:
+ star = self._parse_star()
+ expressions = [star] if star else self._parse_csv(self._parse_json_key_value)
null_handling = None
if self._match_text_seq("NULL", "ON", "NULL"):
@@ -3767,7 +3800,7 @@ class Parser(metaclass=_Parser):
encoding=encoding,
)
- def _parse_logarithm(self) -> exp.Expression:
+ def _parse_logarithm(self) -> exp.Func:
# Default argument order is base, expression
args = self._parse_csv(self._parse_range)
@@ -3780,7 +3813,7 @@ class Parser(metaclass=_Parser):
exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0)
)
- def _parse_match_against(self) -> exp.Expression:
+ def _parse_match_against(self) -> exp.MatchAgainst:
expressions = self._parse_csv(self._parse_column)
self._match_text_seq(")", "AGAINST", "(")
@@ -3803,15 +3836,16 @@ class Parser(metaclass=_Parser):
)
# https://learn.microsoft.com/en-us/sql/t-sql/functions/openjson-transact-sql?view=sql-server-ver16
- def _parse_open_json(self) -> exp.Expression:
+ def _parse_open_json(self) -> exp.OpenJSON:
this = self._parse_bitwise()
path = self._match(TokenType.COMMA) and self._parse_string()
- def _parse_open_json_column_def() -> exp.Expression:
+ def _parse_open_json_column_def() -> exp.OpenJSONColumnDef:
this = self._parse_field(any_token=True)
kind = self._parse_types()
path = self._parse_string()
as_json = self._match_pair(TokenType.ALIAS, TokenType.JSON)
+
return self.expression(
exp.OpenJSONColumnDef, this=this, kind=kind, path=path, as_json=as_json
)
@@ -3823,7 +3857,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.OpenJSON, this=this, path=path, expressions=expressions)
- def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
+ def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition:
args = self._parse_csv(self._parse_bitwise)
if self._match(TokenType.IN):
@@ -3838,17 +3872,15 @@ class Parser(metaclass=_Parser):
needle = seq_get(args, 0)
haystack = seq_get(args, 1)
- this = exp.StrPosition(this=haystack, substr=needle, position=seq_get(args, 2))
-
- self.validate_expression(this, args)
-
- return this
+ return self.expression(
+ exp.StrPosition, this=haystack, substr=needle, position=seq_get(args, 2)
+ )
- def _parse_join_hint(self, func_name: str) -> exp.Expression:
+ def _parse_join_hint(self, func_name: str) -> exp.JoinHint:
args = self._parse_csv(self._parse_table)
return exp.JoinHint(this=func_name.upper(), expressions=args)
- def _parse_substring(self) -> exp.Expression:
+ def _parse_substring(self) -> exp.Substring:
# Postgres supports the form: substring(string [from int] [for int])
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
@@ -3859,12 +3891,9 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.FOR):
args.append(self._parse_bitwise())
- this = exp.Substring.from_arg_list(args)
- self.validate_expression(this, args)
-
- return this
+ return self.validate_expression(exp.Substring.from_arg_list(args), args)
- def _parse_trim(self) -> exp.Expression:
+ def _parse_trim(self) -> exp.Trim:
# https://www.w3resource.com/sql/character-functions/trim.php
# https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html
@@ -3885,11 +3914,7 @@ class Parser(metaclass=_Parser):
collation = self._parse_bitwise()
return self.expression(
- exp.Trim,
- this=this,
- position=position,
- expression=expression,
- collation=collation,
+ exp.Trim, this=this, position=position, expression=expression, collation=collation
)
def _parse_window_clause(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
@@ -4047,7 +4072,7 @@ class Parser(metaclass=_Parser):
return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev)
return self._parse_placeholder()
- def _parse_string_as_identifier(self) -> t.Optional[exp.Expression]:
+ def _parse_string_as_identifier(self) -> t.Optional[exp.Identifier]:
return exp.to_identifier(self._match(TokenType.STRING) and self._prev.text, quoted=True)
def _parse_number(self) -> t.Optional[exp.Expression]:
@@ -4097,7 +4122,7 @@ class Parser(metaclass=_Parser):
return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
return None
- def _parse_parameter(self) -> exp.Expression:
+ def _parse_parameter(self) -> exp.Parameter:
wrapped = self._match(TokenType.L_BRACE)
this = self._parse_var() or self._parse_identifier() or self._parse_primary()
self._match(TokenType.R_BRACE)
@@ -4183,7 +4208,7 @@ class Parser(metaclass=_Parser):
self._parse_set_operations(self._parse_select(nested=True, parse_subquery_alias=False))
)
- def _parse_transaction(self) -> exp.Expression:
+ def _parse_transaction(self) -> exp.Transaction:
this = None
if self._match_texts(self.TRANSACTION_KIND):
this = self._prev.text
@@ -4203,7 +4228,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Transaction, this=this, modes=modes)
- def _parse_commit_or_rollback(self) -> exp.Expression:
+ def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback:
chain = None
savepoint = None
is_rollback = self._prev.token_type == TokenType.ROLLBACK
@@ -4220,6 +4245,7 @@ class Parser(metaclass=_Parser):
if is_rollback:
return self.expression(exp.Rollback, savepoint=savepoint)
+
return self.expression(exp.Commit, chain=chain)
def _parse_add_column(self) -> t.Optional[exp.Expression]:
@@ -4243,19 +4269,19 @@ class Parser(metaclass=_Parser):
return expression
- def _parse_drop_column(self) -> t.Optional[exp.Expression]:
+ def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]:
drop = self._match(TokenType.DROP) and self._parse_drop()
if drop and not isinstance(drop, exp.Command):
drop.set("kind", drop.args.get("kind", "COLUMN"))
return drop
# https://docs.aws.amazon.com/athena/latest/ug/alter-table-drop-partition.html
- def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.Expression:
+ def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.DropPartition:
return self.expression(
exp.DropPartition, expressions=self._parse_csv(self._parse_partition), exists=exists
)
- def _parse_add_constraint(self) -> t.Optional[exp.Expression]:
+ def _parse_add_constraint(self) -> exp.AddConstraint:
this = None
kind = self._prev.token_type
@@ -4288,7 +4314,7 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return self._parse_csv(self._parse_add_column)
- def _parse_alter_table_alter(self) -> exp.Expression:
+ def _parse_alter_table_alter(self) -> exp.AlterColumn:
self._match(TokenType.COLUMN)
column = self._parse_field(any_token=True)
@@ -4316,11 +4342,11 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return self._parse_csv(self._parse_drop_column)
- def _parse_alter_table_rename(self) -> exp.Expression:
+ def _parse_alter_table_rename(self) -> exp.RenameTable:
self._match_text_seq("TO")
return self.expression(exp.RenameTable, this=self._parse_table(schema=True))
- def _parse_alter(self) -> t.Optional[exp.Expression]:
+ def _parse_alter(self) -> exp.AlterTable | exp.Command:
start = self._prev
if not self._match(TokenType.TABLE):
@@ -4345,7 +4371,7 @@ class Parser(metaclass=_Parser):
)
return self._parse_as_command(start)
- def _parse_merge(self) -> exp.Expression:
+ def _parse_merge(self) -> exp.Merge:
self._match(TokenType.INTO)
target = self._parse_table()
@@ -4412,7 +4438,7 @@ class Parser(metaclass=_Parser):
)
def _parse_show(self) -> t.Optional[exp.Expression]:
- parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) # type: ignore
+ parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE)
if parser:
return parser(self)
self._advance()
@@ -4433,17 +4459,9 @@ class Parser(metaclass=_Parser):
return None
right = self._parse_statement() or self._parse_id_var()
- this = self.expression(
- exp.EQ,
- this=left,
- expression=right,
- )
+ this = self.expression(exp.EQ, this=left, expression=right)
- return self.expression(
- exp.SetItem,
- this=this,
- kind=kind,
- )
+ return self.expression(exp.SetItem, this=this, kind=kind)
def _parse_set_transaction(self, global_: bool = False) -> exp.Expression:
self._match_text_seq("TRANSACTION")
@@ -4458,10 +4476,10 @@ class Parser(metaclass=_Parser):
)
def _parse_set_item(self) -> t.Optional[exp.Expression]:
- parser = self._find_parser(self.SET_PARSERS, self._set_trie) # type: ignore
+ parser = self._find_parser(self.SET_PARSERS, self.SET_TRIE)
return parser(self) if parser else self._parse_set_item_assignment(kind=None)
- def _parse_set(self) -> exp.Expression:
+ def _parse_set(self) -> exp.Set | exp.Command:
index = self._index
set_ = self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
@@ -4471,10 +4489,10 @@ class Parser(metaclass=_Parser):
return set_
- def _parse_var_from_options(self, options: t.Collection[str]) -> t.Optional[exp.Expression]:
+ def _parse_var_from_options(self, options: t.Collection[str]) -> t.Optional[exp.Var]:
for option in options:
if self._match_text_seq(*option.split(" ")):
- return exp.Var(this=option)
+ return exp.var(option)
return None
def _parse_as_command(self, start: Token) -> exp.Command: