path: root/sqlglot/parser.py
author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-01-31 05:44:41 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-01-31 05:44:41 +0000
commit     376de8b6892deca7dc5d83035c047f1e13eb67ea (patch)
tree       334a1753cd914294aa99128fac3fb59bf14dc10f /sqlglot/parser.py
parent     Releasing debian version 20.9.0-1. (diff)
Merging upstream version 20.11.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r--  sqlglot/parser.py | 161
1 file changed, 120 insertions(+), 41 deletions(-)
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 790ee0d..c091605 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -12,9 +12,7 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import TrieResult, in_trie, new_trie
if t.TYPE_CHECKING:
- from typing_extensions import Literal
-
- from sqlglot._typing import E
+ from sqlglot._typing import E, Lit
from sqlglot.dialects.dialect import Dialect, DialectType
logger = logging.getLogger("sqlglot")
@@ -148,6 +146,11 @@ class Parser(metaclass=_Parser):
TokenType.ENUM16,
}
+ AGGREGATE_TYPE_TOKENS = {
+ TokenType.AGGREGATEFUNCTION,
+ TokenType.SIMPLEAGGREGATEFUNCTION,
+ }
+
TYPE_TOKENS = {
TokenType.BIT,
TokenType.BOOLEAN,
@@ -241,6 +244,7 @@ class Parser(metaclass=_Parser):
TokenType.NULL,
*ENUM_TYPE_TOKENS,
*NESTED_TYPE_TOKENS,
+ *AGGREGATE_TYPE_TOKENS,
}
SIGNED_TO_UNSIGNED_TYPE_TOKEN = {
@@ -653,9 +657,11 @@ class Parser(metaclass=_Parser):
PLACEHOLDER_PARSERS = {
TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder),
TokenType.PARAMETER: lambda self: self._parse_parameter(),
- TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
- if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
- else None,
+ TokenType.COLON: lambda self: (
+ self.expression(exp.Placeholder, this=self._prev.text)
+ if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
+ else None
+ ),
}
RANGE_PARSERS = {
@@ -705,6 +711,9 @@ class Parser(metaclass=_Parser):
"IMMUTABLE": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
+ "INHERITS": lambda self: self.expression(
+ exp.InheritsProperty, expressions=self._parse_wrapped_csv(self._parse_table)
+ ),
"INPUT": lambda self: self.expression(exp.InputModelProperty, this=self._parse_schema()),
"JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs),
"LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
@@ -822,6 +831,7 @@ class Parser(metaclass=_Parser):
ALTER_PARSERS = {
"ADD": lambda self: self._parse_alter_table_add(),
"ALTER": lambda self: self._parse_alter_table_alter(),
+ "CLUSTER BY": lambda self: self._parse_cluster(wrapped=True),
"DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()),
"DROP": lambda self: self._parse_alter_table_drop(),
"RENAME": lambda self: self._parse_alter_table_rename(),
@@ -973,6 +983,9 @@ class Parser(metaclass=_Parser):
MODIFIERS_ATTACHED_TO_UNION = True
UNION_MODIFIERS = {"order", "limit", "offset"}
+ # parses no parenthesis if statements as commands
+ NO_PAREN_IF_COMMANDS = True
+
__slots__ = (
"error_level",
"error_message_context",
@@ -1207,7 +1220,20 @@ class Parser(metaclass=_Parser):
if index != self._index:
self._advance(index - self._index)
+ def _warn_unsupported(self) -> None:
+ if len(self._tokens) <= 1:
+ return
+
+ # We use _find_sql because self.sql may comprise multiple chunks, and we're only
+ # interested in emitting a warning for the one being currently processed.
+ sql = self._find_sql(self._tokens[0], self._tokens[-1])[: self.error_message_context]
+
+ logger.warning(
+ f"'{sql}' contains unsupported syntax. Falling back to parsing as a 'Command'."
+ )
+
def _parse_command(self) -> exp.Command:
+ self._warn_unsupported()
return self.expression(
exp.Command, this=self._prev.text.upper(), expression=self._parse_string()
)
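_warn_unsupported reports fallbacks through the module-level "sqlglot" logger declared at the top of this file, so the new message can be tuned or silenced with standard logging configuration, for example:

    import logging

    # Hide the new "Falling back to parsing as a 'Command'" warnings
    # while keeping errors visible.
    logging.getLogger("sqlglot").setLevel(logging.ERROR)
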
@@ -1329,8 +1355,10 @@ class Parser(metaclass=_Parser):
start = self._prev
comments = self._prev_comments
- replace = start.text.upper() == "REPLACE" or self._match_pair(
- TokenType.OR, TokenType.REPLACE
+ replace = (
+ start.token_type == TokenType.REPLACE
+ or self._match_pair(TokenType.OR, TokenType.REPLACE)
+ or self._match_pair(TokenType.OR, TokenType.ALTER)
)
unique = self._match(TokenType.UNIQUE)
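Besides OR REPLACE, the create parser now also accepts OR ALTER (T-SQL's CREATE OR ALTER ...) and records it via the same replace flag. A rough check, assuming the dialect does not override this path:

    import sqlglot

    tree = sqlglot.parse_one("CREATE OR ALTER VIEW v AS SELECT 1", read="tsql")
    print(type(tree).__name__, tree.args.get("replace"))  # expected: Create True
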
@@ -1440,6 +1468,9 @@ class Parser(metaclass=_Parser):
exp.Clone, this=self._parse_table(schema=True), shallow=shallow, copy=copy
)
+ if self._curr:
+ return self._parse_as_command(start)
+
return self.expression(
exp.Create,
comments=comments,
@@ -1516,11 +1547,13 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.FileFormatProperty,
- this=self.expression(
- exp.InputOutputFormat, input_format=input_format, output_format=output_format
- )
- if input_format or output_format
- else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
+ this=(
+ self.expression(
+ exp.InputOutputFormat, input_format=input_format, output_format=output_format
+ )
+ if input_format or output_format
+ else self._parse_var_or_string() or self._parse_number() or self._parse_id_var()
+ ),
)
def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E:
@@ -1632,8 +1665,15 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT))
- def _parse_cluster(self) -> exp.Cluster:
- return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))
+ def _parse_cluster(self, wrapped: bool = False) -> exp.Cluster:
+ return self.expression(
+ exp.Cluster,
+ expressions=(
+ self._parse_wrapped_csv(self._parse_ordered)
+ if wrapped
+ else self._parse_csv(self._parse_ordered)
+ ),
+ )
def _parse_clustered_by(self) -> exp.ClusteredByProperty:
self._match_text_seq("BY")
@@ -2681,6 +2721,8 @@ class Parser(metaclass=_Parser):
else:
columns = None
+ include = self._parse_wrapped_id_vars() if self._match_text_seq("INCLUDE") else None
+
return self.expression(
exp.Index,
this=index,
@@ -2690,6 +2732,7 @@ class Parser(metaclass=_Parser):
unique=unique,
primary=primary,
amp=amp,
+ include=include,
partition_by=self._parse_partition_by(),
where=self._parse_where(),
)
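The index parser now also captures an INCLUDE (...) column list, as used by PostgreSQL and T-SQL covering indexes. A quick check (behavior may vary per dialect):

    import sqlglot
    from sqlglot import exp

    tree = sqlglot.parse_one("CREATE INDEX idx ON t (a) INCLUDE (b, c)", read="postgres")
    index = tree.find(exp.Index)
    if index:
        print([i.sql() for i in (index.args.get("include") or [])])  # expected: ['b', 'c']
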
@@ -3380,8 +3423,8 @@ class Parser(metaclass=_Parser):
def _parse_comparison(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_range, self.COMPARISON)
- def _parse_range(self) -> t.Optional[exp.Expression]:
- this = self._parse_bitwise()
+ def _parse_range(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
+ this = this or self._parse_bitwise()
negate = self._match(TokenType.NOT)
if self._match_set(self.RANGE_PARSERS):
@@ -3535,14 +3578,21 @@ class Parser(metaclass=_Parser):
return self._parse_tokens(self._parse_factor, self.TERM)
def _parse_factor(self) -> t.Optional[exp.Expression]:
- if self.EXPONENT:
- factor = self._parse_tokens(self._parse_exponent, self.FACTOR)
- else:
- factor = self._parse_tokens(self._parse_unary, self.FACTOR)
- if isinstance(factor, exp.Div):
- factor.args["typed"] = self.dialect.TYPED_DIVISION
- factor.args["safe"] = self.dialect.SAFE_DIVISION
- return factor
+ parse_method = self._parse_exponent if self.EXPONENT else self._parse_unary
+ this = parse_method()
+
+ while self._match_set(self.FACTOR):
+ this = self.expression(
+ self.FACTOR[self._prev.token_type],
+ this=this,
+ comments=self._prev_comments,
+ expression=parse_method(),
+ )
+ if isinstance(this, exp.Div):
+ this.args["typed"] = self.dialect.TYPED_DIVISION
+ this.args["safe"] = self.dialect.SAFE_DIVISION
+
+ return this
def _parse_exponent(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_unary, self.EXPONENT)
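The rewritten _parse_factor builds each FACTOR operator explicitly, so the dialect's TYPED_DIVISION / SAFE_DIVISION flags are attached to every Div in a chain rather than only the outermost one. One way to inspect this (flag values depend on the dialect chosen):

    import sqlglot
    from sqlglot import exp

    tree = sqlglot.parse_one("SELECT a / b / c", read="bigquery")
    for div in tree.find_all(exp.Div):
        print(div.sql(), div.args.get("typed"), div.args.get("safe"))
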
@@ -3617,6 +3667,7 @@ class Parser(metaclass=_Parser):
return exp.DataType.build(type_name, udt=True)
else:
+ self._retreat(self._index - 1)
return None
else:
return None
@@ -3631,6 +3682,7 @@ class Parser(metaclass=_Parser):
nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token in self.STRUCT_TYPE_TOKENS
+ is_aggregate = type_token in self.AGGREGATE_TYPE_TOKENS
expressions = None
maybe_func = False
@@ -3645,6 +3697,18 @@ class Parser(metaclass=_Parser):
)
elif type_token in self.ENUM_TYPE_TOKENS:
expressions = self._parse_csv(self._parse_equality)
+ elif is_aggregate:
+ func_or_ident = self._parse_function(anonymous=True) or self._parse_id_var(
+ any_token=False, tokens=(TokenType.VAR,)
+ )
+ if not func_or_ident or not self._match(TokenType.COMMA):
+ return None
+ expressions = self._parse_csv(
+ lambda: self._parse_types(
+ check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
+ )
+ )
+ expressions.insert(0, func_or_ident)
else:
expressions = self._parse_csv(self._parse_type_size)
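The is_aggregate branch targets ClickHouse's AggregateFunction / SimpleAggregateFunction types, which take a function name (or full call) followed by argument types. A hedged round-trip sketch; exact output depends on the ClickHouse generator:

    import sqlglot

    sql = "CREATE TABLE t (s SimpleAggregateFunction(sum, UInt64))"
    print(sqlglot.parse_one(sql, read="clickhouse").sql(dialect="clickhouse"))
    # expected: the aggregate column type survives the round trip
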
@@ -4413,6 +4477,10 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
else:
index = self._index - 1
+
+ if self.NO_PAREN_IF_COMMANDS and index == 0:
+ return self._parse_as_command(self._prev)
+
condition = self._parse_conjunction()
if not condition:
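With NO_PAREN_IF_COMMANDS enabled, an IF that starts a statement and has no parenthesized condition is handed off to _parse_as_command (which also emits the new warning). For instance, in T-SQL-style scripts, assuming the dialect keeps this default:

    import sqlglot
    from sqlglot import exp

    tree = sqlglot.parse_one("IF OBJECT_ID('t') IS NOT NULL SELECT 1", read="tsql")
    print(isinstance(tree, exp.Command))  # expected: True
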
@@ -4624,12 +4692,10 @@ class Parser(metaclass=_Parser):
return None
@t.overload
- def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
- ...
+ def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
@t.overload
- def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
- ...
+ def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
def _parse_json_object(self, agg=False):
star = self._parse_star()
@@ -4974,11 +5040,12 @@ class Parser(metaclass=_Parser):
if alias:
this = self.expression(exp.Alias, comments=comments, this=this, alias=alias)
+ column = this.this
# Moves the comment next to the alias in `expr /* comment */ AS alias`
- if not this.comments and this.this.comments:
- this.comments = this.this.comments
- this.this.comments = None
+ if not this.comments and column and column.comments:
+ this.comments = column.comments
+ column.comments = None
return this
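The guard around `column` avoids an attribute error when the aliased expression is missing or has no comments; when a comment does precede AS, it is still relocated onto the Alias node. A sketch:

    import sqlglot

    alias = sqlglot.parse_one("SELECT x /* primary key */ AS y").selects[0]
    print(alias.comments)  # expected: [' primary key ']
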
@@ -5244,7 +5311,7 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("CHECK"):
expression = self._parse_wrapped(self._parse_conjunction)
- enforced = self._match_text_seq("ENFORCED")
+ enforced = self._match_text_seq("ENFORCED") or False
return self.expression(
exp.AddConstraint, this=this, expression=expression, enforced=enforced
@@ -5278,6 +5345,8 @@ class Parser(metaclass=_Parser):
return self.expression(exp.AlterColumn, this=column, drop=True)
if self._match_pair(TokenType.SET, TokenType.DEFAULT):
return self.expression(exp.AlterColumn, this=column, default=self._parse_conjunction())
+ if self._match(TokenType.COMMENT):
+ return self.expression(exp.AlterColumn, this=column, comment=self._parse_string())
self._match_text_seq("SET", "DATA")
return self.expression(
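ALTER COLUMN now also accepts a trailing COMMENT 'text' clause (Spark/Databricks-style DDL). A hedged sketch:

    import sqlglot
    from sqlglot import exp

    sql = "ALTER TABLE t ALTER COLUMN c COMMENT 'customer id'"
    alter_col = sqlglot.parse_one(sql, read="databricks").find(exp.AlterColumn)
    if alter_col and alter_col.args.get("comment"):
        print(alter_col.args["comment"].sql())  # expected: 'customer id'
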
@@ -5298,7 +5367,18 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return self._parse_csv(self._parse_drop_column)
- def _parse_alter_table_rename(self) -> exp.RenameTable:
+ def _parse_alter_table_rename(self) -> t.Optional[exp.RenameTable | exp.RenameColumn]:
+ if self._match(TokenType.COLUMN):
+ exists = self._parse_exists()
+ old_column = self._parse_column()
+ to = self._match_text_seq("TO")
+ new_column = self._parse_column()
+
+ if old_column is None or to is None or new_column is None:
+ return None
+
+ return self.expression(exp.RenameColumn, this=old_column, to=new_column, exists=exists)
+
self._match_text_seq("TO")
return self.expression(exp.RenameTable, this=self._parse_table(schema=True))
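RENAME is now split: RENAME COLUMN old TO new yields exp.RenameColumn, while the existing RENAME TO path still yields exp.RenameTable. For example:

    import sqlglot
    from sqlglot import exp

    rename = sqlglot.parse_one("ALTER TABLE t RENAME COLUMN a TO b").find(exp.RenameColumn)
    if rename:
        print(rename.this.sql(), rename.args["to"].sql())  # expected: a b
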
@@ -5319,7 +5399,7 @@ class Parser(metaclass=_Parser):
if parser:
actions = ensure_list(parser(self))
- if not self._curr:
+ if not self._curr and actions:
return self.expression(
exp.AlterTable,
this=this,
@@ -5467,6 +5547,7 @@ class Parser(metaclass=_Parser):
self._advance()
text = self._find_sql(start, self._prev)
size = len(start.text)
+ self._warn_unsupported()
return exp.Command(this=text[:size], expression=text[size:])
def _parse_dict_property(self, this: str) -> exp.DictProperty:
@@ -5634,7 +5715,7 @@ class Parser(metaclass=_Parser):
if advance:
self._advance()
return True
- return False
+ return None
def _match_text_seq(self, *texts, advance=True):
index = self._index
@@ -5643,7 +5724,7 @@ class Parser(metaclass=_Parser):
self._advance()
else:
self._retreat(index)
- return False
+ return None
if not advance:
self._retreat(index)
@@ -5651,14 +5732,12 @@ class Parser(metaclass=_Parser):
return True
@t.overload
- def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression:
- ...
+ def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ...
@t.overload
def _replace_columns_with_dots(
self, this: t.Optional[exp.Expression]
- ) -> t.Optional[exp.Expression]:
- ...
+ ) -> t.Optional[exp.Expression]: ...
def _replace_columns_with_dots(self, this):
if isinstance(this, exp.Dot):