summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/parser.py
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-06-02 23:59:40 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-06-02 23:59:46 +0000
commit: 20739a12c39121a9e7ad3c9a2469ec5a6876199d (patch)
tree: c000de91c59fd29b2d9beecf9f93b84e69727f37 /sqlglot/parser.py
parent: Releasing debian version 12.2.0-1. (diff)
download: sqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.tar.xz
download: sqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.zip
Merging upstream version 15.0.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- sqlglot/parser.py 896
1 files changed, 525 insertions, 371 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index d8d9f88..e77bb5a 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -6,22 +6,17 @@ from collections import defaultdict
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
-from sqlglot.helper import (
- apply_index_offset,
- count_params,
- ensure_collection,
- ensure_list,
- seq_get,
-)
+from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie
-logger = logging.getLogger("sqlglot")
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
-E = t.TypeVar("E", bound=exp.Expression)
+logger = logging.getLogger("sqlglot")
-def parse_var_map(args: t.Sequence) -> exp.Expression:
+def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
if len(args) == 1 and args[0].is_star:
return exp.StarMap(this=args[0])
@@ -36,7 +31,7 @@ def parse_var_map(args: t.Sequence) -> exp.Expression:
)
-def parse_like(args):
+def parse_like(args: t.List) -> exp.Expression:
like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0))
return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like
@@ -65,7 +60,7 @@ class Parser(metaclass=_Parser):
Args:
error_level: the desired error level.
- Default: ErrorLevel.RAISE
+ Default: ErrorLevel.IMMEDIATE
error_message_context: determines the amount of context to capture from a
query string when displaying the error message (in number of characters).
Default: 50.
@@ -118,8 +113,8 @@ class Parser(metaclass=_Parser):
NESTED_TYPE_TOKENS = {
TokenType.ARRAY,
TokenType.MAP,
- TokenType.STRUCT,
TokenType.NULLABLE,
+ TokenType.STRUCT,
}
TYPE_TOKENS = {
@@ -158,6 +153,7 @@ class Parser(metaclass=_Parser):
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
TokenType.DATETIME,
+ TokenType.DATETIME64,
TokenType.DATE,
TokenType.DECIMAL,
TokenType.BIGDECIMAL,
@@ -211,20 +207,18 @@ class Parser(metaclass=_Parser):
TokenType.VAR,
TokenType.ANTI,
TokenType.APPLY,
+ TokenType.ASC,
TokenType.AUTO_INCREMENT,
TokenType.BEGIN,
- TokenType.BOTH,
- TokenType.BUCKET,
TokenType.CACHE,
- TokenType.CASCADE,
TokenType.COLLATE,
TokenType.COMMAND,
TokenType.COMMENT,
TokenType.COMMIT,
- TokenType.COMPOUND,
TokenType.CONSTRAINT,
TokenType.DEFAULT,
TokenType.DELETE,
+ TokenType.DESC,
TokenType.DESCRIBE,
TokenType.DIV,
TokenType.END,
@@ -233,7 +227,6 @@ class Parser(metaclass=_Parser):
TokenType.FALSE,
TokenType.FIRST,
TokenType.FILTER,
- TokenType.FOLLOWING,
TokenType.FORMAT,
TokenType.FULL,
TokenType.IF,
@@ -241,41 +234,31 @@ class Parser(metaclass=_Parser):
TokenType.ISNULL,
TokenType.INTERVAL,
TokenType.KEEP,
- TokenType.LAZY,
- TokenType.LEADING,
TokenType.LEFT,
- TokenType.LOCAL,
- TokenType.MATERIALIZED,
+ TokenType.LOAD,
TokenType.MERGE,
TokenType.NATURAL,
TokenType.NEXT,
TokenType.OFFSET,
- TokenType.ONLY,
- TokenType.OPTIONS,
TokenType.ORDINALITY,
TokenType.OVERWRITE,
TokenType.PARTITION,
TokenType.PERCENT,
TokenType.PIVOT,
TokenType.PRAGMA,
- TokenType.PRECEDING,
TokenType.RANGE,
TokenType.REFERENCES,
TokenType.RIGHT,
TokenType.ROW,
TokenType.ROWS,
- TokenType.SEED,
TokenType.SEMI,
TokenType.SET,
+ TokenType.SETTINGS,
TokenType.SHOW,
- TokenType.SORTKEY,
TokenType.TEMPORARY,
TokenType.TOP,
- TokenType.TRAILING,
TokenType.TRUE,
- TokenType.UNBOUNDED,
TokenType.UNIQUE,
- TokenType.UNLOGGED,
TokenType.UNPIVOT,
TokenType.VOLATILE,
TokenType.WINDOW,
@@ -291,6 +274,7 @@ class Parser(metaclass=_Parser):
TokenType.APPLY,
TokenType.FULL,
TokenType.LEFT,
+ TokenType.LOCK,
TokenType.NATURAL,
TokenType.OFFSET,
TokenType.RIGHT,
@@ -301,7 +285,7 @@ class Parser(metaclass=_Parser):
UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
- TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
+ TRIM_TYPES = {"LEADING", "TRAILING", "BOTH"}
FUNC_TOKENS = {
TokenType.COMMAND,
@@ -322,6 +306,7 @@ class Parser(metaclass=_Parser):
TokenType.MERGE,
TokenType.OFFSET,
TokenType.PRIMARY_KEY,
+ TokenType.RANGE,
TokenType.REPLACE,
TokenType.ROW,
TokenType.UNNEST,
@@ -455,31 +440,31 @@ class Parser(metaclass=_Parser):
}
EXPRESSION_PARSERS = {
+ exp.Cluster: lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
exp.Column: lambda self: self._parse_column(),
+ exp.Condition: lambda self: self._parse_conjunction(),
exp.DataType: lambda self: self._parse_types(),
+ exp.Expression: lambda self: self._parse_statement(),
exp.From: lambda self: self._parse_from(),
exp.Group: lambda self: self._parse_group(),
+ exp.Having: lambda self: self._parse_having(),
exp.Identifier: lambda self: self._parse_id_var(),
- exp.Lateral: lambda self: self._parse_lateral(),
exp.Join: lambda self: self._parse_join(),
- exp.Order: lambda self: self._parse_order(),
- exp.Cluster: lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
- exp.Sort: lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
exp.Lambda: lambda self: self._parse_lambda(),
+ exp.Lateral: lambda self: self._parse_lateral(),
exp.Limit: lambda self: self._parse_limit(),
exp.Offset: lambda self: self._parse_offset(),
- exp.TableAlias: lambda self: self._parse_table_alias(),
- exp.Table: lambda self: self._parse_table(),
- exp.Condition: lambda self: self._parse_conjunction(),
- exp.Expression: lambda self: self._parse_statement(),
- exp.Properties: lambda self: self._parse_properties(),
- exp.Where: lambda self: self._parse_where(),
+ exp.Order: lambda self: self._parse_order(),
exp.Ordered: lambda self: self._parse_ordered(),
- exp.Having: lambda self: self._parse_having(),
- exp.With: lambda self: self._parse_with(),
- exp.Window: lambda self: self._parse_named_window(),
+ exp.Properties: lambda self: self._parse_properties(),
exp.Qualify: lambda self: self._parse_qualify(),
exp.Returning: lambda self: self._parse_returning(),
+ exp.Sort: lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
+ exp.Table: lambda self: self._parse_table_parts(),
+ exp.TableAlias: lambda self: self._parse_table_alias(),
+ exp.Where: lambda self: self._parse_where(),
+ exp.Window: lambda self: self._parse_named_window(),
+ exp.With: lambda self: self._parse_with(),
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
}
@@ -495,9 +480,13 @@ class Parser(metaclass=_Parser):
TokenType.DESCRIBE: lambda self: self._parse_describe(),
TokenType.DROP: lambda self: self._parse_drop(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
+ TokenType.FROM: lambda self: exp.select("*").from_(
+ t.cast(exp.From, self._parse_from(skip_from_token=True))
+ ),
TokenType.INSERT: lambda self: self._parse_insert(),
- TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
+ TokenType.LOAD: lambda self: self._parse_load(),
TokenType.MERGE: lambda self: self._parse_merge(),
+ TokenType.PIVOT: lambda self: self._parse_simplified_pivot(),
TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.SET: lambda self: self._parse_set(),
@@ -536,7 +525,10 @@ class Parser(metaclass=_Parser):
TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
- TokenType.NATIONAL: lambda self, token: self._parse_national(token),
+ TokenType.NATIONAL_STRING: lambda self, token: self.expression(
+ exp.National, this=token.text
+ ),
+ TokenType.RAW_STRING: lambda self, token: self.expression(exp.RawString, this=token.text),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
}
@@ -551,91 +543,76 @@ class Parser(metaclass=_Parser):
RANGE_PARSERS = {
TokenType.BETWEEN: lambda self, this: self._parse_between(this),
TokenType.GLOB: binary_range_parser(exp.Glob),
- TokenType.OVERLAPS: binary_range_parser(exp.Overlaps),
+ TokenType.ILIKE: binary_range_parser(exp.ILike),
TokenType.IN: lambda self, this: self._parse_in(this),
+ TokenType.IRLIKE: binary_range_parser(exp.RegexpILike),
TokenType.IS: lambda self, this: self._parse_is(this),
TokenType.LIKE: binary_range_parser(exp.Like),
- TokenType.ILIKE: binary_range_parser(exp.ILike),
- TokenType.IRLIKE: binary_range_parser(exp.RegexpILike),
+ TokenType.OVERLAPS: binary_range_parser(exp.Overlaps),
TokenType.RLIKE: binary_range_parser(exp.RegexpLike),
TokenType.SIMILAR_TO: binary_range_parser(exp.SimilarTo),
}
- PROPERTY_PARSERS = {
- "AFTER": lambda self: self._parse_afterjournal(
- no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
- ),
+ PROPERTY_PARSERS: t.Dict[str, t.Callable] = {
"ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
"AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
- "BEFORE": lambda self: self._parse_journal(
- no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
- ),
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
"CHARACTER SET": lambda self: self._parse_character_set(),
"CHECKSUM": lambda self: self._parse_checksum(),
- "CLUSTER BY": lambda self: self.expression(
- exp.Cluster, expressions=self._parse_csv(self._parse_ordered)
- ),
+ "CLUSTER": lambda self: self._parse_cluster(),
"COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
- "DATABLOCKSIZE": lambda self: self._parse_datablocksize(
- default=self._prev.text.upper() == "DEFAULT"
- ),
+ "DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs),
"DEFINER": lambda self: self._parse_definer(),
"DETERMINISTIC": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
"DISTKEY": lambda self: self._parse_distkey(),
"DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
+ "ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty),
"EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
"EXTERNAL": lambda self: self.expression(exp.ExternalProperty),
- "FALLBACK": lambda self: self._parse_fallback(no=self._prev.text.upper() == "NO"),
+ "FALLBACK": lambda self, **kwargs: self._parse_fallback(**kwargs),
"FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"FREESPACE": lambda self: self._parse_freespace(),
- "GLOBAL": lambda self: self._parse_temporary(global_=True),
"IMMUTABLE": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
- "JOURNAL": lambda self: self._parse_journal(
- no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
- ),
+ "JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs),
"LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
"LIKE": lambda self: self._parse_create_like(),
- "LOCAL": lambda self: self._parse_afterjournal(no=False, dual=False, local=True),
"LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty),
"LOCK": lambda self: self._parse_locking(),
"LOCKING": lambda self: self._parse_locking(),
- "LOG": lambda self: self._parse_log(no=self._prev.text.upper() == "NO"),
+ "LOG": lambda self, **kwargs: self._parse_log(**kwargs),
"MATERIALIZED": lambda self: self.expression(exp.MaterializedProperty),
- "MAX": lambda self: self._parse_datablocksize(),
- "MAXIMUM": lambda self: self._parse_datablocksize(),
- "MERGEBLOCKRATIO": lambda self: self._parse_mergeblockratio(
- no=self._prev.text.upper() == "NO", default=self._prev.text.upper() == "DEFAULT"
- ),
- "MIN": lambda self: self._parse_datablocksize(),
- "MINIMUM": lambda self: self._parse_datablocksize(),
+ "MERGEBLOCKRATIO": lambda self, **kwargs: self._parse_mergeblockratio(**kwargs),
"MULTISET": lambda self: self.expression(exp.SetProperty, multi=True),
- "NO": lambda self: self._parse_noprimaryindex(),
- "NOT": lambda self: self._parse_afterjournal(no=False, dual=False, local=False),
- "ON": lambda self: self._parse_oncommit(),
+ "NO": lambda self: self._parse_no_property(),
+ "ON": lambda self: self._parse_on_property(),
+ "ORDER BY": lambda self: self._parse_order(skip_order_token=True),
"PARTITION BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
+ "PRIMARY KEY": lambda self: self._parse_primary_key(),
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
"ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
"SET": lambda self: self.expression(exp.SetProperty, multi=False),
+ "SETTINGS": lambda self: self.expression(
+ exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item)
+ ),
"SORTKEY": lambda self: self._parse_sortkey(),
"STABLE": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("STABLE")
),
"STORED": lambda self: self._parse_stored(),
- "TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
"TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
- "TEMP": lambda self: self._parse_temporary(global_=False),
- "TEMPORARY": lambda self: self._parse_temporary(global_=False),
+ "TEMP": lambda self: self.expression(exp.TemporaryProperty),
+ "TEMPORARY": lambda self: self.expression(exp.TemporaryProperty),
"TRANSIENT": lambda self: self.expression(exp.TransientProperty),
- "USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
+ "TTL": lambda self: self._parse_ttl(),
+ "USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"VOLATILE": lambda self: self._parse_volatile_property(),
"WITH": lambda self: self._parse_with_property(),
}
@@ -679,6 +656,7 @@ class Parser(metaclass=_Parser):
"TITLE": lambda self: self.expression(
exp.TitleColumnConstraint, this=self._parse_var_or_string()
),
+ "TTL": lambda self: self.expression(exp.MergeTreeTTL, expressions=[self._parse_bitwise()]),
"UNIQUE": lambda self: self._parse_unique(),
"UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint),
}
@@ -704,6 +682,8 @@ class Parser(metaclass=_Parser):
),
}
+ FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}
+
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
@@ -712,7 +692,9 @@ class Parser(metaclass=_Parser):
"JSON_OBJECT": lambda self: self._parse_json_object(),
"LOG": lambda self: self._parse_logarithm(),
"MATCH": lambda self: self._parse_match_against(),
+ "OPENJSON": lambda self: self._parse_open_json(),
"POSITION": lambda self: self._parse_position(),
+ "SAFE_CAST": lambda self: self._parse_cast(False),
"STRING_AGG": lambda self: self._parse_string_agg(),
"SUBSTRING": lambda self: self._parse_substring(),
"TRIM": lambda self: self._parse_trim(),
@@ -721,19 +703,18 @@ class Parser(metaclass=_Parser):
}
QUERY_MODIFIER_PARSERS = {
+ "joins": lambda self: list(iter(self._parse_join, None)),
+ "laterals": lambda self: list(iter(self._parse_lateral, None)),
"match": lambda self: self._parse_match_recognize(),
"where": lambda self: self._parse_where(),
"group": lambda self: self._parse_group(),
"having": lambda self: self._parse_having(),
"qualify": lambda self: self._parse_qualify(),
"windows": lambda self: self._parse_window_clause(),
- "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
- "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
- "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
"order": lambda self: self._parse_order(),
"limit": lambda self: self._parse_limit(),
"offset": lambda self: self._parse_offset(),
- "lock": lambda self: self._parse_lock(),
+ "locks": lambda self: self._parse_locks(),
"sample": lambda self: self._parse_table_sample(as_modifier=True),
}
@@ -763,8 +744,11 @@ class Parser(metaclass=_Parser):
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
+ CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"}
+
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
+ WINDOW_SIDES = {"FOLLOWING", "PRECEDING"}
ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
@@ -772,8 +756,8 @@ class Parser(metaclass=_Parser):
CONVERT_TYPE_FIRST = False
- QUOTED_PIVOT_COLUMNS: t.Optional[bool] = None
PREFIXED_PIVOT_COLUMNS = False
+ IDENTIFY_PIVOT_STRINGS = False
LOG_BASE_FIRST = True
LOG_DEFAULTS_TO_LN = False
@@ -875,7 +859,7 @@ class Parser(metaclass=_Parser):
e.errors[0]["into_expression"] = expression_type
errors.append(e)
raise ParseError(
- f"Failed to parse into {expression_types}",
+ f"Failed to parse '{sql or raw_tokens}' into {expression_types}",
errors=merge_errors(errors),
) from errors[-1]
@@ -933,7 +917,7 @@ class Parser(metaclass=_Parser):
"""
token = token or self._curr or self._prev or Token.string("")
start = token.start
- end = token.end
+ end = token.end + 1
start_context = self.sql[max(start - self.error_message_context, 0) : start]
highlight = self.sql[start:end]
end_context = self.sql[end : end + self.error_message_context]
@@ -996,7 +980,7 @@ class Parser(metaclass=_Parser):
self.raise_error(error_message)
def _find_sql(self, start: Token, end: Token) -> str:
- return self.sql[start.start : end.end]
+ return self.sql[start.start : end.end + 1]
def _advance(self, times: int = 1) -> None:
self._index += times
@@ -1042,6 +1026,44 @@ class Parser(metaclass=_Parser):
exp.Comment, this=this, kind=kind.text, expression=self._parse_string(), exists=exists
)
+ # https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl
+ def _parse_ttl(self) -> exp.Expression:
+ def _parse_ttl_action() -> t.Optional[exp.Expression]:
+ this = self._parse_bitwise()
+
+ if self._match_text_seq("DELETE"):
+ return self.expression(exp.MergeTreeTTLAction, this=this, delete=True)
+ if self._match_text_seq("RECOMPRESS"):
+ return self.expression(
+ exp.MergeTreeTTLAction, this=this, recompress=self._parse_bitwise()
+ )
+ if self._match_text_seq("TO", "DISK"):
+ return self.expression(
+ exp.MergeTreeTTLAction, this=this, to_disk=self._parse_string()
+ )
+ if self._match_text_seq("TO", "VOLUME"):
+ return self.expression(
+ exp.MergeTreeTTLAction, this=this, to_volume=self._parse_string()
+ )
+
+ return this
+
+ expressions = self._parse_csv(_parse_ttl_action)
+ where = self._parse_where()
+ group = self._parse_group()
+
+ aggregates = None
+ if group and self._match(TokenType.SET):
+ aggregates = self._parse_csv(self._parse_set_item)
+
+ return self.expression(
+ exp.MergeTreeTTL,
+ expressions=expressions,
+ where=where,
+ group=group,
+ aggregates=aggregates,
+ )
+
def _parse_statement(self) -> t.Optional[exp.Expression]:
if self._curr is None:
return None
@@ -1054,14 +1076,12 @@ class Parser(metaclass=_Parser):
expression = self._parse_expression()
expression = self._parse_set_operations(expression) if expression else self._parse_select()
-
- self._parse_query_modifiers(expression)
- return expression
+ return self._parse_query_modifiers(expression)
def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]:
start = self._prev
temporary = self._match(TokenType.TEMPORARY)
- materialized = self._match(TokenType.MATERIALIZED)
+ materialized = self._match_text_seq("MATERIALIZED")
kind = self._match_set(self.CREATABLES) and self._prev.text
if not kind:
return self._parse_as_command(start)
@@ -1073,7 +1093,7 @@ class Parser(metaclass=_Parser):
kind=kind,
temporary=temporary,
materialized=materialized,
- cascade=self._match(TokenType.CASCADE),
+ cascade=self._match_text_seq("CASCADE"),
constraints=self._match_text_seq("CONSTRAINTS"),
purge=self._match_text_seq("PURGE"),
)
@@ -1111,6 +1131,7 @@ class Parser(metaclass=_Parser):
indexes = None
no_schema_binding = None
begin = None
+ clone = None
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function(kind=create_token.token_type)
@@ -1128,7 +1149,7 @@ class Parser(metaclass=_Parser):
if return_:
expression = self.expression(exp.Return, this=expression)
elif create_token.token_type == TokenType.INDEX:
- this = self._parse_index()
+ this = self._parse_index(index=self._parse_id_var())
elif create_token.token_type in self.DB_CREATABLES:
table_parts = self._parse_table_parts(schema=True)
@@ -1166,33 +1187,40 @@ class Parser(metaclass=_Parser):
expression = self._parse_ddl_select()
if create_token.token_type == TokenType.TABLE:
- # exp.Properties.Location.POST_EXPRESSION
- temp_properties = self._parse_properties()
- if properties and temp_properties:
- properties.expressions.extend(temp_properties.expressions)
- elif temp_properties:
- properties = temp_properties
-
indexes = []
while True:
- index = self._parse_create_table_index()
+ index = self._parse_index()
- # exp.Properties.Location.POST_INDEX
- if self._match(TokenType.PARTITION_BY, advance=False):
- temp_properties = self._parse_properties()
- if properties and temp_properties:
- properties.expressions.extend(temp_properties.expressions)
- elif temp_properties:
- properties = temp_properties
+ # exp.Properties.Location.POST_EXPRESSION or exp.Properties.Location.POST_INDEX
+ temp_properties = self._parse_properties()
+ if properties and temp_properties:
+ properties.expressions.extend(temp_properties.expressions)
+ elif temp_properties:
+ properties = temp_properties
if not index:
break
else:
+ self._match(TokenType.COMMA)
indexes.append(index)
elif create_token.token_type == TokenType.VIEW:
if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"):
no_schema_binding = True
+ if self._match_text_seq("CLONE"):
+ clone = self._parse_table(schema=True)
+ when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper()
+ clone_kind = (
+ self._match(TokenType.L_PAREN)
+ and self._match_texts(self.CLONE_KINDS)
+ and self._prev.text.upper()
+ )
+ clone_expression = self._match(TokenType.FARROW) and self._parse_bitwise()
+ self._match(TokenType.R_PAREN)
+ clone = self.expression(
+ exp.Clone, this=clone, when=when, kind=clone_kind, expression=clone_expression
+ )
+
return self.expression(
exp.Create,
this=this,
@@ -1205,18 +1233,31 @@ class Parser(metaclass=_Parser):
indexes=indexes,
no_schema_binding=no_schema_binding,
begin=begin,
+ clone=clone,
)
def _parse_property_before(self) -> t.Optional[exp.Expression]:
+ # only used for teradata currently
self._match(TokenType.COMMA)
- # parsers look to _prev for no/dual/default, so need to consume first
- self._match_text_seq("NO")
- self._match_text_seq("DUAL")
- self._match_text_seq("DEFAULT")
+ kwargs = {
+ "no": self._match_text_seq("NO"),
+ "dual": self._match_text_seq("DUAL"),
+ "before": self._match_text_seq("BEFORE"),
+ "default": self._match_text_seq("DEFAULT"),
+ "local": (self._match_text_seq("LOCAL") and "LOCAL")
+ or (self._match_text_seq("NOT", "LOCAL") and "NOT LOCAL"),
+ "after": self._match_text_seq("AFTER"),
+ "minimum": self._match_texts(("MIN", "MINIMUM")),
+ "maximum": self._match_texts(("MAX", "MAXIMUM")),
+ }
- if self.PROPERTY_PARSERS.get(self._curr.text.upper()):
- return self.PROPERTY_PARSERS[self._curr.text.upper()](self)
+ if self._match_texts(self.PROPERTY_PARSERS):
+ parser = self.PROPERTY_PARSERS[self._prev.text.upper()]
+ try:
+ return parser(self, **{k: v for k, v in kwargs.items() if v})
+ except TypeError:
+ self.raise_error(f"Cannot parse property '{self._prev.text}'")
return None
@@ -1227,7 +1268,7 @@ class Parser(metaclass=_Parser):
if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
return self._parse_character_set(default=True)
- if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
+ if self._match_text_seq("COMPOUND", "SORTKEY"):
return self._parse_sortkey(compound=True)
if self._match_text_seq("SQL", "SECURITY"):
@@ -1262,23 +1303,20 @@ class Parser(metaclass=_Parser):
def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression:
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
- return self.expression(
- exp_class,
- this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
- )
+ return self.expression(exp_class, this=self._parse_field())
- def _parse_properties(self, before=None) -> t.Optional[exp.Expression]:
+ def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Expression]:
properties = []
while True:
if before:
- identified_property = self._parse_property_before()
+ prop = self._parse_property_before()
else:
- identified_property = self._parse_property()
+ prop = self._parse_property()
- if not identified_property:
+ if not prop:
break
- for p in ensure_list(identified_property):
+ for p in ensure_list(prop):
properties.append(p)
if properties:
@@ -1286,8 +1324,7 @@ class Parser(metaclass=_Parser):
return None
- def _parse_fallback(self, no=False) -> exp.Expression:
- self._match_text_seq("FALLBACK")
+ def _parse_fallback(self, no: bool = False) -> exp.Expression:
return self.expression(
exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION")
)
@@ -1345,23 +1382,13 @@ class Parser(metaclass=_Parser):
self._match(TokenType.EQ)
return self.expression(exp.WithJournalTableProperty, this=self._parse_table_parts())
- def _parse_log(self, no=False) -> exp.Expression:
- self._match_text_seq("LOG")
+ def _parse_log(self, no: bool = False) -> exp.Expression:
return self.expression(exp.LogProperty, no=no)
- def _parse_journal(self, no=False, dual=False) -> exp.Expression:
- before = self._match_text_seq("BEFORE")
- self._match_text_seq("JOURNAL")
- return self.expression(exp.JournalProperty, no=no, dual=dual, before=before)
-
- def _parse_afterjournal(self, no=False, dual=False, local=None) -> exp.Expression:
- self._match_text_seq("NOT")
- self._match_text_seq("LOCAL")
- self._match_text_seq("AFTER", "JOURNAL")
- return self.expression(exp.AfterJournalProperty, no=no, dual=dual, local=local)
+ def _parse_journal(self, **kwargs) -> exp.Expression:
+ return self.expression(exp.JournalProperty, **kwargs)
def _parse_checksum(self) -> exp.Expression:
- self._match_text_seq("CHECKSUM")
self._match(TokenType.EQ)
on = None
@@ -1377,49 +1404,55 @@ class Parser(metaclass=_Parser):
default=default,
)
+ def _parse_cluster(self) -> t.Optional[exp.Expression]:
+ if not self._match_text_seq("BY"):
+ self._retreat(self._index - 1)
+ return None
+ return self.expression(
+ exp.Cluster,
+ expressions=self._parse_csv(self._parse_ordered),
+ )
+
def _parse_freespace(self) -> exp.Expression:
- self._match_text_seq("FREESPACE")
self._match(TokenType.EQ)
return self.expression(
exp.FreespaceProperty, this=self._parse_number(), percent=self._match(TokenType.PERCENT)
)
- def _parse_mergeblockratio(self, no=False, default=False) -> exp.Expression:
- self._match_text_seq("MERGEBLOCKRATIO")
+ def _parse_mergeblockratio(self, no: bool = False, default: bool = False) -> exp.Expression:
if self._match(TokenType.EQ):
return self.expression(
exp.MergeBlockRatioProperty,
this=self._parse_number(),
percent=self._match(TokenType.PERCENT),
)
- else:
- return self.expression(
- exp.MergeBlockRatioProperty,
- no=no,
- default=default,
- )
+ return self.expression(
+ exp.MergeBlockRatioProperty,
+ no=no,
+ default=default,
+ )
- def _parse_datablocksize(self, default=None) -> exp.Expression:
- if default:
- self._match_text_seq("DATABLOCKSIZE")
- return self.expression(exp.DataBlocksizeProperty, default=True)
- elif self._match_texts(("MIN", "MINIMUM")):
- self._match_text_seq("DATABLOCKSIZE")
- return self.expression(exp.DataBlocksizeProperty, min=True)
- elif self._match_texts(("MAX", "MAXIMUM")):
- self._match_text_seq("DATABLOCKSIZE")
- return self.expression(exp.DataBlocksizeProperty, min=False)
-
- self._match_text_seq("DATABLOCKSIZE")
+ def _parse_datablocksize(
+ self,
+ default: t.Optional[bool] = None,
+ minimum: t.Optional[bool] = None,
+ maximum: t.Optional[bool] = None,
+ ) -> exp.Expression:
self._match(TokenType.EQ)
size = self._parse_number()
units = None
if self._match_texts(("BYTES", "KBYTES", "KILOBYTES")):
units = self._prev.text
- return self.expression(exp.DataBlocksizeProperty, size=size, units=units)
+ return self.expression(
+ exp.DataBlocksizeProperty,
+ size=size,
+ units=units,
+ default=default,
+ minimum=minimum,
+ maximum=maximum,
+ )
def _parse_blockcompression(self) -> exp.Expression:
- self._match_text_seq("BLOCKCOMPRESSION")
self._match(TokenType.EQ)
always = self._match_text_seq("ALWAYS")
manual = self._match_text_seq("MANUAL")
@@ -1516,7 +1549,7 @@ class Parser(metaclass=_Parser):
this=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
- def _parse_withdata(self, no=False) -> exp.Expression:
+ def _parse_withdata(self, no: bool = False) -> exp.Expression:
if self._match_text_seq("AND", "STATISTICS"):
statistics = True
elif self._match_text_seq("AND", "NO", "STATISTICS"):
@@ -1526,13 +1559,17 @@ class Parser(metaclass=_Parser):
return self.expression(exp.WithDataProperty, no=no, statistics=statistics)
- def _parse_noprimaryindex(self) -> exp.Expression:
- self._match_text_seq("PRIMARY", "INDEX")
- return exp.NoPrimaryIndexProperty()
+ def _parse_no_property(self) -> t.Optional[exp.Property]:
+ if self._match_text_seq("PRIMARY", "INDEX"):
+ return exp.NoPrimaryIndexProperty()
+ return None
- def _parse_oncommit(self) -> exp.Expression:
- self._match_text_seq("COMMIT", "PRESERVE", "ROWS")
- return exp.OnCommitProperty()
+ def _parse_on_property(self) -> t.Optional[exp.Property]:
+ if self._match_text_seq("COMMIT", "PRESERVE", "ROWS"):
+ return exp.OnCommitProperty()
+ elif self._match_text_seq("COMMIT", "DELETE", "ROWS"):
+ return exp.OnCommitProperty(delete=True)
+ return None
def _parse_distkey(self) -> exp.Expression:
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
@@ -1587,10 +1624,6 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
- def _parse_temporary(self, global_=False) -> exp.Expression:
- self._match(TokenType.TEMPORARY) # in case calling from "GLOBAL"
- return self.expression(exp.TemporaryProperty, global_=global_)
-
def _parse_describe(self) -> exp.Expression:
kind = self._match_set(self.CREATABLES) and self._prev.text
this = self._parse_table()
@@ -1599,7 +1632,7 @@ class Parser(metaclass=_Parser):
def _parse_insert(self) -> exp.Expression:
overwrite = self._match(TokenType.OVERWRITE)
- local = self._match(TokenType.LOCAL)
+ local = self._match_text_seq("LOCAL")
alternative = None
if self._match_text_seq("DIRECTORY"):
@@ -1700,23 +1733,25 @@ class Parser(metaclass=_Parser):
return self.expression(exp.RowFormatDelimitedProperty, **kwargs) # type: ignore
- def _parse_load_data(self) -> exp.Expression:
- local = self._match(TokenType.LOCAL)
- self._match_text_seq("INPATH")
- inpath = self._parse_string()
- overwrite = self._match(TokenType.OVERWRITE)
- self._match_pair(TokenType.INTO, TokenType.TABLE)
+ def _parse_load(self) -> exp.Expression:
+ if self._match_text_seq("DATA"):
+ local = self._match_text_seq("LOCAL")
+ self._match_text_seq("INPATH")
+ inpath = self._parse_string()
+ overwrite = self._match(TokenType.OVERWRITE)
+ self._match_pair(TokenType.INTO, TokenType.TABLE)
- return self.expression(
- exp.LoadData,
- this=self._parse_table(schema=True),
- local=local,
- overwrite=overwrite,
- inpath=inpath,
- partition=self._parse_partition(),
- input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(),
- serde=self._match_text_seq("SERDE") and self._parse_string(),
- )
+ return self.expression(
+ exp.LoadData,
+ this=self._parse_table(schema=True),
+ local=local,
+ overwrite=overwrite,
+ inpath=inpath,
+ partition=self._parse_partition(),
+ input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(),
+ serde=self._match_text_seq("SERDE") and self._parse_string(),
+ )
+ return self._parse_as_command(self._prev)
def _parse_delete(self) -> exp.Expression:
self._match(TokenType.FROM)
@@ -1735,7 +1770,7 @@ class Parser(metaclass=_Parser):
**{ # type: ignore
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
"expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
- "from": self._parse_from(),
+ "from": self._parse_from(modifiers=True),
"where": self._parse_where(),
"returning": self._parse_returning(),
},
@@ -1752,12 +1787,12 @@ class Parser(metaclass=_Parser):
)
def _parse_cache(self) -> exp.Expression:
- lazy = self._match(TokenType.LAZY)
+ lazy = self._match_text_seq("LAZY")
self._match(TokenType.TABLE)
table = self._parse_table(schema=True)
options = []
- if self._match(TokenType.OPTIONS):
+ if self._match_text_seq("OPTIONS"):
self._match_l_paren()
k = self._parse_string()
self._match(TokenType.EQ)
@@ -1851,11 +1886,10 @@ class Parser(metaclass=_Parser):
if from_:
this.set("from", from_)
- self._parse_query_modifiers(this)
+ this = self._parse_query_modifiers(this)
elif (table or nested) and self._match(TokenType.L_PAREN):
this = self._parse_table() if table else self._parse_select(nested=True)
- self._parse_query_modifiers(this)
- this = self._parse_set_operations(this)
+ this = self._parse_set_operations(self._parse_query_modifiers(this))
self._match_r_paren()
# early return so that subquery unions aren't parsed again
@@ -1868,6 +1902,10 @@ class Parser(metaclass=_Parser):
expressions=self._parse_csv(self._parse_value),
alias=self._parse_table_alias(),
)
+ elif self._match(TokenType.PIVOT):
+ this = self._parse_simplified_pivot()
+ elif self._match(TokenType.FROM):
+ this = exp.select("*").from_(t.cast(exp.From, self._parse_from(skip_from_token=True)))
else:
this = None
@@ -1929,7 +1967,9 @@ class Parser(metaclass=_Parser):
def _parse_subquery(
self, this: t.Optional[exp.Expression], parse_alias: bool = True
- ) -> exp.Expression:
+ ) -> t.Optional[exp.Expression]:
+ if not this:
+ return None
return self.expression(
exp.Subquery,
this=this,
@@ -1937,35 +1977,16 @@ class Parser(metaclass=_Parser):
alias=self._parse_table_alias() if parse_alias else None,
)
- def _parse_query_modifiers(self, this: t.Optional[exp.Expression]) -> None:
- if not isinstance(this, self.MODIFIABLES):
- return
-
- table = isinstance(this, exp.Table)
-
- while True:
- join = self._parse_join()
- if join:
- this.append("joins", join)
-
- lateral = None
- if not join:
- lateral = self._parse_lateral()
- if lateral:
- this.append("laterals", lateral)
-
- comma = None if table else self._match(TokenType.COMMA)
- if comma:
- this.args["from"].append("expressions", self._parse_table())
-
- if not (lateral or join or comma):
- break
-
- for key, parser in self.QUERY_MODIFIER_PARSERS.items():
- expression = parser(self)
+ def _parse_query_modifiers(
+ self, this: t.Optional[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
+ if isinstance(this, self.MODIFIABLES):
+ for key, parser in self.QUERY_MODIFIER_PARSERS.items():
+ expression = parser(self)
- if expression:
- this.set(key, expression)
+ if expression:
+ this.set(key, expression)
+ return this
def _parse_hint(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.HINT):
@@ -1981,19 +2002,26 @@ class Parser(metaclass=_Parser):
return None
temp = self._match(TokenType.TEMPORARY)
- unlogged = self._match(TokenType.UNLOGGED)
+ unlogged = self._match_text_seq("UNLOGGED")
self._match(TokenType.TABLE)
return self.expression(
exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged
)
- def _parse_from(self) -> t.Optional[exp.Expression]:
- if not self._match(TokenType.FROM):
+ def _parse_from(
+ self, modifiers: bool = False, skip_from_token: bool = False
+ ) -> t.Optional[exp.From]:
+ if not skip_from_token and not self._match(TokenType.FROM):
return None
+ comments = self._prev_comments
+ this = self._parse_table()
+
return self.expression(
- exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table)
+ exp.From,
+ comments=comments,
+ this=self._parse_query_modifiers(this) if modifiers else this,
)
def _parse_match_recognize(self) -> t.Optional[exp.Expression]:
@@ -2136,6 +2164,9 @@ class Parser(metaclass=_Parser):
)
def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
+ if self._match(TokenType.COMMA):
+ return self.expression(exp.Join, this=self._parse_table())
+
index = self._index
natural, side, kind = self._parse_join_side_and_kind()
hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
@@ -2176,55 +2207,66 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Join, **kwargs) # type: ignore
- def _parse_index(self) -> exp.Expression:
- index = self._parse_id_var()
- self._match(TokenType.ON)
- self._match(TokenType.TABLE) # hive
+ def _parse_index(
+ self,
+ index: t.Optional[exp.Expression] = None,
+ ) -> t.Optional[exp.Expression]:
+ if index:
+ unique = None
+ primary = None
+ amp = None
- return self.expression(
- exp.Index,
- this=index,
- table=self.expression(exp.Table, this=self._parse_id_var()),
- columns=self._parse_expression(),
- )
+ self._match(TokenType.ON)
+ self._match(TokenType.TABLE) # hive
+ table = self._parse_table_parts(schema=True)
+ else:
+ unique = self._match(TokenType.UNIQUE)
+ primary = self._match_text_seq("PRIMARY")
+ amp = self._match_text_seq("AMP")
+ if not self._match(TokenType.INDEX):
+ return None
+ index = self._parse_id_var()
+ table = None
- def _parse_create_table_index(self) -> t.Optional[exp.Expression]:
- unique = self._match(TokenType.UNIQUE)
- primary = self._match_text_seq("PRIMARY")
- amp = self._match_text_seq("AMP")
- if not self._match(TokenType.INDEX):
- return None
- index = self._parse_id_var()
- columns = None
if self._match(TokenType.L_PAREN, advance=False):
- columns = self._parse_wrapped_csv(self._parse_column)
+ columns = self._parse_wrapped_csv(self._parse_ordered)
+ else:
+ columns = None
+
return self.expression(
exp.Index,
this=index,
+ table=table,
columns=columns,
unique=unique,
primary=primary,
amp=amp,
+ partition_by=self._parse_partition_by(),
)
- def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
- catalog = None
- db = None
-
- table = (
+ def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
+ return (
(not schema and self._parse_function())
or self._parse_id_var(any_token=False)
or self._parse_string_as_identifier()
+ or self._parse_placeholder()
)
+ def _parse_table_parts(self, schema: bool = False) -> exp.Table:
+ catalog = None
+ db = None
+ table = self._parse_table_part(schema=schema)
+
while self._match(TokenType.DOT):
if catalog:
# This allows nesting the table in arbitrarily many dot expressions if needed
- table = self.expression(exp.Dot, this=table, expression=self._parse_id_var())
+ table = self.expression(
+ exp.Dot, this=table, expression=self._parse_table_part(schema=schema)
+ )
else:
catalog = db
db = table
- table = self._parse_id_var()
+ table = self._parse_table_part(schema=schema)
if not table:
self.raise_error(f"Expected table name but got {self._curr}")
@@ -2237,28 +2279,24 @@ class Parser(metaclass=_Parser):
self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
) -> t.Optional[exp.Expression]:
lateral = self._parse_lateral()
-
if lateral:
return lateral
unnest = self._parse_unnest()
-
if unnest:
return unnest
values = self._parse_derived_table_values()
-
if values:
return values
subquery = self._parse_select(table=True)
-
if subquery:
if not subquery.args.get("pivots"):
subquery.set("pivots", self._parse_pivots())
return subquery
- this = self._parse_table_parts(schema=schema)
+ this: exp.Expression = self._parse_table_parts(schema=schema)
if schema:
return self._parse_schema(this=this)
@@ -2267,7 +2305,6 @@ class Parser(metaclass=_Parser):
table_sample = self._parse_table_sample()
alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
-
if alias:
this.set("alias", alias)
@@ -2352,9 +2389,9 @@ class Parser(metaclass=_Parser):
num = self._parse_number()
- if self._match(TokenType.BUCKET):
+ if self._match_text_seq("BUCKET"):
bucket_numerator = self._parse_number()
- self._match(TokenType.OUT_OF)
+ self._match_text_seq("OUT", "OF")
bucket_denominator = bucket_denominator = self._parse_number()
self._match(TokenType.ON)
bucket_field = self._parse_field()
@@ -2390,6 +2427,22 @@ class Parser(metaclass=_Parser):
def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]:
return list(iter(self._parse_pivot, None))
+ # https://duckdb.org/docs/sql/statements/pivot
+ def _parse_simplified_pivot(self) -> exp.Pivot:
+ def _parse_on() -> t.Optional[exp.Expression]:
+ this = self._parse_bitwise()
+ return self._parse_in(this) if self._match(TokenType.IN) else this
+
+ this = self._parse_table()
+ expressions = self._match(TokenType.ON) and self._parse_csv(_parse_on)
+ using = self._match(TokenType.USING) and self._parse_csv(
+ lambda: self._parse_alias(self._parse_function())
+ )
+ group = self._parse_group()
+ return self.expression(
+ exp.Pivot, this=this, expressions=expressions, using=using, group=group
+ )
+
def _parse_pivot(self) -> t.Optional[exp.Expression]:
index = self._index
@@ -2423,7 +2476,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.IN):
self.raise_error("Expecting IN")
- field = self._parse_in(value)
+ field = self._parse_in(value, alias=True)
self._match_r_paren()
@@ -2436,21 +2489,22 @@ class Parser(metaclass=_Parser):
names = self._pivot_column_names(t.cast(t.List[exp.Expression], expressions))
columns: t.List[exp.Expression] = []
- for col in pivot.args["field"].expressions:
+ for fld in pivot.args["field"].expressions:
+ field_name = fld.sql() if self.IDENTIFY_PIVOT_STRINGS else fld.alias_or_name
for name in names:
if self.PREFIXED_PIVOT_COLUMNS:
- name = f"{name}_{col.alias_or_name}" if name else col.alias_or_name
+ name = f"{name}_{field_name}" if name else field_name
else:
- name = f"{col.alias_or_name}_{name}" if name else col.alias_or_name
+ name = f"{field_name}_{name}" if name else field_name
- columns.append(exp.to_identifier(name, quoted=self.QUOTED_PIVOT_COLUMNS))
+ columns.append(exp.to_identifier(name))
pivot.set("columns", columns)
return pivot
- def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
- return [agg.alias for agg in pivot_columns]
+ def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
+ return [agg.alias for agg in aggregations]
def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_where_token and not self._match(TokenType.WHERE):
@@ -2477,6 +2531,7 @@ class Parser(metaclass=_Parser):
rollup = None
cube = None
+ totals = None
with_ = self._match(TokenType.WITH)
if self._match(TokenType.ROLLUP):
@@ -2487,7 +2542,11 @@ class Parser(metaclass=_Parser):
cube = with_ or self._parse_wrapped_csv(self._parse_column)
elements["cube"].extend(ensure_list(cube))
- if not (expressions or grouping_sets or rollup or cube):
+ if self._match_text_seq("TOTALS"):
+ totals = True
+ elements["totals"] = True # type: ignore
+
+ if not (grouping_sets or rollup or cube or totals):
break
return self.expression(exp.Group, **elements) # type: ignore
@@ -2527,9 +2586,9 @@ class Parser(metaclass=_Parser):
)
def _parse_sort(
- self, token_type: TokenType, exp_class: t.Type[exp.Expression]
+ self, exp_class: t.Type[exp.Expression], *texts: str
) -> t.Optional[exp.Expression]:
- if not self._match(token_type):
+ if not self._match_text_seq(*texts):
return None
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
@@ -2537,8 +2596,8 @@ class Parser(metaclass=_Parser):
this = self._parse_conjunction()
self._match(TokenType.ASC)
is_desc = self._match(TokenType.DESC)
- is_nulls_first = self._match(TokenType.NULLS_FIRST)
- is_nulls_last = self._match(TokenType.NULLS_LAST)
+ is_nulls_first = self._match_text_seq("NULLS", "FIRST")
+ is_nulls_last = self._match_text_seq("NULLS", "LAST")
desc = is_desc or False
asc = not desc
nulls_first = is_nulls_first or False
@@ -2578,7 +2637,7 @@ class Parser(metaclass=_Parser):
self._match_set((TokenType.ROW, TokenType.ROWS))
- only = self._match(TokenType.ONLY)
+ only = self._match_text_seq("ONLY")
with_ties = self._match_text_seq("WITH", "TIES")
if only and with_ties:
@@ -2602,13 +2661,37 @@ class Parser(metaclass=_Parser):
self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count)
- def _parse_lock(self) -> t.Optional[exp.Expression]:
- if self._match_text_seq("FOR", "UPDATE"):
- return self.expression(exp.Lock, update=True)
- if self._match_text_seq("FOR", "SHARE"):
- return self.expression(exp.Lock, update=False)
+ def _parse_locks(self) -> t.List[exp.Expression]:
+ # Lists are invariant, so we need to use a type hint here
+ locks: t.List[exp.Expression] = []
- return None
+ while True:
+ if self._match_text_seq("FOR", "UPDATE"):
+ update = True
+ elif self._match_text_seq("FOR", "SHARE") or self._match_text_seq(
+ "LOCK", "IN", "SHARE", "MODE"
+ ):
+ update = False
+ else:
+ break
+
+ expressions = None
+ if self._match_text_seq("OF"):
+ expressions = self._parse_csv(lambda: self._parse_table(schema=True))
+
+ wait: t.Optional[bool | exp.Expression] = None
+ if self._match_text_seq("NOWAIT"):
+ wait = True
+ elif self._match_text_seq("WAIT"):
+ wait = self._parse_primary()
+ elif self._match_text_seq("SKIP", "LOCKED"):
+ wait = False
+
+ locks.append(
+ self.expression(exp.Lock, update=update, expressions=expressions, wait=wait)
+ )
+
+ return locks
def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match_set(self.SET_OPERATIONS):
@@ -2672,7 +2755,7 @@ class Parser(metaclass=_Parser):
def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
index = self._index - 1
negate = self._match(TokenType.NOT)
- if self._match(TokenType.DISTINCT_FROM):
+ if self._match_text_seq("DISTINCT", "FROM"):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
return self.expression(klass, this=this, expression=self._parse_expression())
@@ -2684,12 +2767,12 @@ class Parser(metaclass=_Parser):
this = self.expression(exp.Is, this=this, expression=expression)
return self.expression(exp.Not, this=this) if negate else this
- def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression:
+ def _parse_in(self, this: t.Optional[exp.Expression], alias: bool = False) -> exp.In:
unnest = self._parse_unnest()
if unnest:
this = self.expression(exp.In, this=this, unnest=unnest)
elif self._match(TokenType.L_PAREN):
- expressions = self._parse_csv(self._parse_select_or_expression)
+ expressions = self._parse_csv(lambda: self._parse_select_or_expression(alias=alias))
if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
this = self.expression(exp.In, this=this, query=expressions[0])
@@ -2722,15 +2805,19 @@ class Parser(metaclass=_Parser):
# Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
# each INTERVAL expression into this canonical form so it's easy to transpile
- if this and isinstance(this, exp.Literal):
- if this.is_number:
- this = exp.Literal.string(this.name)
-
- # Try to not clutter Snowflake's multi-part intervals like INTERVAL '1 day, 1 year'
+ if this and this.is_number:
+ this = exp.Literal.string(this.name)
+ elif this and this.is_string:
parts = this.name.split()
- if not unit and len(parts) <= 2:
- this = exp.Literal.string(seq_get(parts, 0))
- unit = self.expression(exp.Var, this=seq_get(parts, 1))
+
+ if len(parts) == 2:
+ if unit:
+ # this is not actually a unit, it's something else
+ unit = None
+ self._retreat(self._index - 1)
+ else:
+ this = exp.Literal.string(parts[0])
+ unit = self.expression(exp.Var, this=parts[1])
return self.expression(exp.Interval, this=this, unit=unit)
@@ -2783,13 +2870,22 @@ class Parser(metaclass=_Parser):
if parser:
return parser(self, this, data_type)
return self.expression(exp.Cast, this=this, to=data_type)
- if not data_type.args.get("expressions"):
+ if not data_type.expressions:
self._retreat(index)
return self._parse_column()
- return data_type
+ return self._parse_column_ops(data_type)
return this
+ def _parse_type_size(self) -> t.Optional[exp.Expression]:
+ this = self._parse_type()
+ if not this:
+ return None
+
+ return self.expression(
+ exp.DataTypeSize, this=this, expression=self._parse_var(any_token=True)
+ )
+
def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
index = self._index
@@ -2814,7 +2910,7 @@ class Parser(metaclass=_Parser):
elif nested:
expressions = self._parse_csv(self._parse_types)
else:
- expressions = self._parse_csv(self._parse_conjunction)
+ expressions = self._parse_csv(self._parse_type_size)
if not expressions or not self._match(TokenType.R_PAREN):
self._retreat(index)
@@ -2858,13 +2954,14 @@ class Parser(metaclass=_Parser):
value: t.Optional[exp.Expression] = None
if type_token in self.TIMESTAMPS:
- if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
+ if self._match_text_seq("WITH", "TIME", "ZONE") or type_token == TokenType.TIMESTAMPTZ:
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
elif (
- self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
+ self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE")
+ or type_token == TokenType.TIMESTAMPLTZ
):
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
- elif self._match(TokenType.WITHOUT_TIME_ZONE):
+ elif self._match_text_seq("WITHOUT", "TIME", "ZONE"):
if type_token == TokenType.TIME:
value = exp.DataType(this=exp.DataType.Type.TIME, expressions=expressions)
else:
@@ -2909,7 +3006,7 @@ class Parser(metaclass=_Parser):
return self._parse_column_def(this)
def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
- if not self._match(TokenType.AT_TIME_ZONE):
+ if not self._match_text_seq("AT", "TIME", "ZONE"):
return this
return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())
@@ -2919,6 +3016,9 @@ class Parser(metaclass=_Parser):
this = self.expression(exp.Column, this=this)
elif not this:
return self._parse_bracket(this)
+ return self._parse_column_ops(this)
+
+ def _parse_column_ops(self, this: exp.Expression) -> exp.Expression:
this = self._parse_bracket(this)
while self._match_set(self.COLUMN_OPERATORS):
@@ -2929,7 +3029,7 @@ class Parser(metaclass=_Parser):
field = self._parse_types()
if not field:
self.raise_error("Expected type")
- elif op:
+ elif op and self._curr:
self._advance()
value = self._prev.text
field = (
@@ -2963,7 +3063,6 @@ class Parser(metaclass=_Parser):
else:
this = self.expression(exp.Dot, this=this, expression=field)
this = self._parse_bracket(this)
-
return this
def _parse_primary(self) -> t.Optional[exp.Expression]:
@@ -2989,12 +3088,9 @@ class Parser(metaclass=_Parser):
if query:
expressions = [query]
else:
- expressions = self._parse_csv(
- lambda: self._parse_alias(self._parse_conjunction(), explicit=True)
- )
+ expressions = self._parse_csv(self._parse_expression)
- this = seq_get(expressions, 0)
- self._parse_query_modifiers(this)
+ this = self._parse_query_modifiers(seq_get(expressions, 0))
if isinstance(this, exp.Subqueryable):
this = self._parse_set_operations(
@@ -3065,20 +3161,12 @@ class Parser(metaclass=_Parser):
functions = self.FUNCTIONS
function = functions.get(upper)
- args = self._parse_csv(self._parse_lambda)
- if function and not anonymous:
- # Clickhouse supports function calls like foo(x, y)(z), so for these we need to also parse the
- # second parameter list (i.e. "(z)") and the corresponding function will receive both arg lists.
- if count_params(function) == 2:
- params = None
- if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
- params = self._parse_csv(self._parse_lambda)
-
- this = function(args, params)
- else:
- this = function(args)
+ alias = upper in self.FUNCTIONS_WITH_ALIASED_ARGS
+ args = self._parse_csv(lambda: self._parse_lambda(alias=alias))
+ if function and not anonymous:
+ this = function(args)
self.validate_expression(this, args)
else:
this = self.expression(exp.Anonymous, this=this, expressions=args)
@@ -3113,9 +3201,6 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Identifier, this=token.text)
- def _parse_national(self, token: Token) -> exp.Expression:
- return self.expression(exp.National, this=exp.Literal.string(token.text))
-
def _parse_session_parameter(self) -> exp.Expression:
kind = None
this = self._parse_id_var() or self._parse_primary()
@@ -3126,7 +3211,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.SessionParameter, this=this, kind=kind)
- def _parse_lambda(self) -> t.Optional[exp.Expression]:
+ def _parse_lambda(self, alias: bool = False) -> t.Optional[exp.Expression]:
index = self._index
if self._match(TokenType.L_PAREN):
@@ -3149,7 +3234,7 @@ class Parser(metaclass=_Parser):
exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
)
else:
- this = self._parse_select_or_expression()
+ this = self._parse_select_or_expression(alias=alias)
if isinstance(this, exp.EQ):
left = this.this
@@ -3161,13 +3246,15 @@ class Parser(metaclass=_Parser):
def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
index = self._index
- try:
- if self._parse_select(nested=True):
- return this
- except Exception:
- pass
- finally:
- self._retreat(index)
+ if not self.errors:
+ try:
+ if self._parse_select(nested=True):
+ return this
+ except ParseError:
+ pass
+ finally:
+ self.errors.clear()
+ self._retreat(index)
if not self._match(TokenType.L_PAREN):
return this
@@ -3227,13 +3314,18 @@ class Parser(metaclass=_Parser):
return self.expression(exp.CompressColumnConstraint, this=self._parse_bitwise())
def _parse_generated_as_identity(self) -> exp.Expression:
- if self._match(TokenType.BY_DEFAULT):
- this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False)
+ if self._match_text_seq("BY", "DEFAULT"):
+ on_null = self._match_pair(TokenType.ON, TokenType.NULL)
+ this = self.expression(
+ exp.GeneratedAsIdentityColumnConstraint, this=False, on_null=on_null
+ )
else:
self._match_text_seq("ALWAYS")
this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
- self._match_text_seq("AS", "IDENTITY")
+ self._match(TokenType.ALIAS)
+ identity = self._match_text_seq("IDENTITY")
+
if self._match(TokenType.L_PAREN):
if self._match_text_seq("START", "WITH"):
this.set("start", self._parse_bitwise())
@@ -3249,6 +3341,9 @@ class Parser(metaclass=_Parser):
elif self._match_text_seq("NO", "CYCLE"):
this.set("cycle", False)
+ if not identity:
+ this.set("expression", self._parse_bitwise())
+
self._match_r_paren()
return this
@@ -3307,9 +3402,10 @@ class Parser(metaclass=_Parser):
return self.CONSTRAINT_PARSERS[constraint](self)
def _parse_unique(self) -> exp.Expression:
- if not self._match(TokenType.L_PAREN, advance=False):
- return self.expression(exp.UniqueColumnConstraint)
- return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())
+ self._match_text_seq("KEY")
+ return self.expression(
+ exp.UniqueColumnConstraint, this=self._parse_schema(self._parse_id_var(any_token=False))
+ )
def _parse_key_constraint_options(self) -> t.List[str]:
options = []
@@ -3321,9 +3417,9 @@ class Parser(metaclass=_Parser):
action = None
on = self._advance_any() and self._prev.text
- if self._match(TokenType.NO_ACTION):
+ if self._match_text_seq("NO", "ACTION"):
action = "NO ACTION"
- elif self._match(TokenType.CASCADE):
+ elif self._match_text_seq("CASCADE"):
action = "CASCADE"
elif self._match_pair(TokenType.SET, TokenType.NULL):
action = "SET NULL"
@@ -3348,7 +3444,7 @@ class Parser(metaclass=_Parser):
return options
- def _parse_references(self, match=True) -> t.Optional[exp.Expression]:
+ def _parse_references(self, match: bool = True) -> t.Optional[exp.Expression]:
if match and not self._match(TokenType.REFERENCES):
return None
@@ -3372,7 +3468,7 @@ class Parser(metaclass=_Parser):
kind = self._prev.text.lower()
- if self._match(TokenType.NO_ACTION):
+ if self._match_text_seq("NO", "ACTION"):
action = "NO ACTION"
elif self._match(TokenType.SET):
self._match_set((TokenType.NULL, TokenType.DEFAULT))
@@ -3396,11 +3492,19 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_PAREN, advance=False):
return self.expression(exp.PrimaryKeyColumnConstraint, desc=desc)
- expressions = self._parse_wrapped_id_vars()
+ expressions = self._parse_wrapped_csv(self._parse_field)
options = self._parse_key_constraint_options()
return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
+ @t.overload
+ def _parse_bracket(self, this: exp.Expression) -> exp.Expression:
+ ...
+
+ @t.overload
def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ ...
+
+ def _parse_bracket(self, this):
if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)):
return this
@@ -3493,7 +3597,12 @@ class Parser(metaclass=_Parser):
this = self._parse_conjunction()
if not self._match(TokenType.ALIAS):
- self.raise_error("Expected AS after CAST")
+ if self._match(TokenType.COMMA):
+ return self.expression(
+ exp.CastToStrType, this=this, expression=self._parse_string()
+ )
+ else:
+ self.raise_error("Expected AS after CAST")
to = self._parse_types()
@@ -3524,7 +3633,7 @@ class Parser(metaclass=_Parser):
# Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
# This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that
# the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them.
- if not self._match(TokenType.WITHIN_GROUP):
+ if not self._match_text_seq("WITHIN", "GROUP"):
self._retreat(index)
this = exp.GroupConcat.from_arg_list(args)
self.validate_expression(this, args)
@@ -3674,6 +3783,27 @@ class Parser(metaclass=_Parser):
exp.MatchAgainst, this=this, expressions=expressions, modifier=modifier
)
+ # https://learn.microsoft.com/en-us/sql/t-sql/functions/openjson-transact-sql?view=sql-server-ver16
+ def _parse_open_json(self) -> exp.Expression:
+ this = self._parse_bitwise()
+ path = self._match(TokenType.COMMA) and self._parse_string()
+
+ def _parse_open_json_column_def() -> exp.Expression:
+ this = self._parse_field(any_token=True)
+ kind = self._parse_types()
+ path = self._parse_string()
+ as_json = self._match_pair(TokenType.ALIAS, TokenType.JSON)
+ return self.expression(
+ exp.OpenJSONColumnDef, this=this, kind=kind, path=path, as_json=as_json
+ )
+
+ expressions = None
+ if self._match_pair(TokenType.R_PAREN, TokenType.WITH):
+ self._match_l_paren()
+ expressions = self._parse_csv(_parse_open_json_column_def)
+
+ return self.expression(exp.OpenJSON, this=this, path=path, expressions=expressions)
+
def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
args = self._parse_csv(self._parse_bitwise)
@@ -3722,7 +3852,7 @@ class Parser(metaclass=_Parser):
position = None
collation = None
- if self._match_set(self.TRIM_TYPES):
+ if self._match_texts(self.TRIM_TYPES):
position = self._prev.text.upper()
expression = self._parse_bitwise()
@@ -3752,9 +3882,9 @@ class Parser(metaclass=_Parser):
def _parse_respect_or_ignore_nulls(
self, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
- if self._match(TokenType.IGNORE_NULLS):
+ if self._match_text_seq("IGNORE", "NULLS"):
return self.expression(exp.IgnoreNulls, this=this)
- if self._match(TokenType.RESPECT_NULLS):
+ if self._match_text_seq("RESPECT", "NULLS"):
return self.expression(exp.RespectNulls, this=this)
return this
@@ -3767,7 +3897,7 @@ class Parser(metaclass=_Parser):
# T-SQL allows the OVER (...) syntax after WITHIN GROUP.
# https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16
- if self._match(TokenType.WITHIN_GROUP):
+ if self._match_text_seq("WITHIN", "GROUP"):
order = self._parse_wrapped(self._parse_order)
this = self.expression(exp.WithinGroup, this=this, expression=order)
@@ -3846,10 +3976,11 @@ class Parser(metaclass=_Parser):
return {
"value": (
- self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text
- )
- or self._parse_bitwise(),
- "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text,
+ (self._match_text_seq("UNBOUNDED") and "UNBOUNDED")
+ or (self._match_text_seq("CURRENT", "ROW") and "CURRENT ROW")
+ or self._parse_bitwise()
+ ),
+ "side": self._match_texts(self.WINDOW_SIDES) and self._prev.text,
}
def _parse_alias(
@@ -3956,7 +4087,7 @@ class Parser(metaclass=_Parser):
def _parse_parameter(self) -> exp.Expression:
wrapped = self._match(TokenType.L_BRACE)
- this = self._parse_var() or self._parse_primary()
+ this = self._parse_var() or self._parse_identifier() or self._parse_primary()
self._match(TokenType.R_BRACE)
return self.expression(exp.Parameter, this=this, wrapped=wrapped)
@@ -4011,26 +4142,33 @@ class Parser(metaclass=_Parser):
return this
- def _parse_wrapped_id_vars(self) -> t.List[t.Optional[exp.Expression]]:
- return self._parse_wrapped_csv(self._parse_id_var)
+ def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[t.Optional[exp.Expression]]:
+ return self._parse_wrapped_csv(self._parse_id_var, optional=optional)
def _parse_wrapped_csv(
- self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
+ self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA, optional: bool = False
) -> t.List[t.Optional[exp.Expression]]:
- return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep))
+ return self._parse_wrapped(
+ lambda: self._parse_csv(parse_method, sep=sep), optional=optional
+ )
- def _parse_wrapped(self, parse_method: t.Callable) -> t.Any:
- self._match_l_paren()
+ def _parse_wrapped(self, parse_method: t.Callable, optional: bool = False) -> t.Any:
+ wrapped = self._match(TokenType.L_PAREN)
+ if not wrapped and not optional:
+ self.raise_error("Expecting (")
parse_result = parse_method()
- self._match_r_paren()
+ if wrapped:
+ self._match_r_paren()
return parse_result
- def _parse_select_or_expression(self) -> t.Optional[exp.Expression]:
- return self._parse_select() or self._parse_set_operations(self._parse_expression())
+ def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]:
+ return self._parse_select() or self._parse_set_operations(
+ self._parse_expression() if alias else self._parse_conjunction()
+ )
def _parse_ddl_select(self) -> t.Optional[exp.Expression]:
- return self._parse_set_operations(
- self._parse_select(nested=True, parse_subquery_alias=False)
+ return self._parse_query_modifiers(
+ self._parse_set_operations(self._parse_select(nested=True, parse_subquery_alias=False))
)
def _parse_transaction(self) -> exp.Expression:
@@ -4391,11 +4529,11 @@ class Parser(metaclass=_Parser):
return None
- def _match_l_paren(self, expression=None):
+ def _match_l_paren(self, expression: t.Optional[exp.Expression] = None) -> None:
if not self._match(TokenType.L_PAREN, expression=expression):
self.raise_error("Expecting (")
- def _match_r_paren(self, expression=None):
+ def _match_r_paren(self, expression: t.Optional[exp.Expression] = None) -> None:
if not self._match(TokenType.R_PAREN, expression=expression):
self.raise_error("Expecting )")
@@ -4420,6 +4558,16 @@ class Parser(metaclass=_Parser):
return True
+ @t.overload
+ def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression:
+ ...
+
+ @t.overload
+ def _replace_columns_with_dots(
+ self, this: t.Optional[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
+ ...
+
def _replace_columns_with_dots(self, this):
if isinstance(this, exp.Dot):
exp.replace_children(this, self._replace_columns_with_dots)
@@ -4433,9 +4581,15 @@ class Parser(metaclass=_Parser):
)
elif isinstance(this, exp.Identifier):
this = self.expression(exp.Var, this=this.name)
+
return this
- def _replace_lambda(self, node, lambda_variables):
+ def _replace_lambda(
+ self, node: t.Optional[exp.Expression], lambda_variables: t.Set[str]
+ ) -> t.Optional[exp.Expression]:
+ if not node:
+ return node
+
for column in node.find_all(exp.Column):
if column.parts[0].name in lambda_variables:
dot_or_id = column.to_dot() if column.table else column.this