Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r--   sqlglot/parser.py   309
1 file changed, 231 insertions(+), 78 deletions(-)
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 67ffd8f..c2cb3a1 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -41,11 +41,17 @@ def build_like(args: t.List) -> exp.Escape | exp.Like:
def binary_range_parser(
- expr_type: t.Type[exp.Expression],
+ expr_type: t.Type[exp.Expression], reverse_args: bool = False
) -> t.Callable[[Parser, t.Optional[exp.Expression]], t.Optional[exp.Expression]]:
- return lambda self, this: self._parse_escape(
- self.expression(expr_type, this=this, expression=self._parse_bitwise())
- )
+ def _parse_binary_range(
+ self: Parser, this: t.Optional[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
+ expression = self._parse_bitwise()
+ if reverse_args:
+ this, expression = expression, this
+ return self._parse_escape(self.expression(expr_type, this=this, expression=expression))
+
+ return _parse_binary_range
def build_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
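A standalone sketch of what the new reverse_args flag changes (fake stand-ins, not sqlglot's real Parser; the factory body mirrors the diff above):

    def binary_range_parser(expr_type, reverse_args=False):
        def _parse_binary_range(self, this):
            expression = self._parse_bitwise()
            if reverse_args:
                this, expression = expression, this
            return self._parse_escape(self.expression(expr_type, this=this, expression=expression))
        return _parse_binary_range

    class FakeParser:
        # Minimal stand-ins for the three Parser hooks the closure touches
        def _parse_bitwise(self):
            return "rhs"
        def _parse_escape(self, node):
            return node
        def expression(self, expr_type, this, expression):
            return (expr_type, this, expression)

    print(binary_range_parser("LIKE")(FakeParser(), "lhs"))                     # ('LIKE', 'lhs', 'rhs')
    print(binary_range_parser("LIKE", reverse_args=True)(FakeParser(), "lhs"))  # ('LIKE', 'rhs', 'lhs')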
@@ -335,6 +341,8 @@ class Parser(metaclass=_Parser):
TokenType.TABLE,
TokenType.TAG,
TokenType.VIEW,
+ TokenType.WAREHOUSE,
+ TokenType.STREAMLIT,
}
CREATABLES = {
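With WAREHOUSE and STREAMLIT added to the creatable tokens, Snowflake-specific DDL should now parse instead of erroring out. A hedged probe via the public API (assuming the Snowflake dialect wires these through to exp.Create):

    import sqlglot

    print(repr(sqlglot.parse_one("CREATE WAREHOUSE test_wh", read="snowflake")))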
@@ -418,6 +426,7 @@ class Parser(metaclass=_Parser):
TokenType.TRUE,
TokenType.TRUNCATE,
TokenType.UNIQUE,
+ TokenType.UNNEST,
TokenType.UNPIVOT,
TokenType.UPDATE,
TokenType.USE,
@@ -580,7 +589,7 @@ class Parser(metaclass=_Parser):
exp.Lambda,
this=self._replace_lambda(
self._parse_conjunction(),
- {node.name for node in expressions},
+ expressions,
),
expressions=expressions,
),
@@ -1125,6 +1134,8 @@ class Parser(metaclass=_Parser):
SELECT_START_TOKENS = {TokenType.L_PAREN, TokenType.WITH, TokenType.SELECT}
+ COPY_INTO_VARLEN_OPTIONS = {"FILE_FORMAT", "COPY_OPTIONS", "FORMAT_OPTIONS", "CREDENTIAL"}
+
STRICT_CAST = True
PREFIXED_PIVOT_COLUMNS = False
@@ -1160,6 +1171,9 @@ class Parser(metaclass=_Parser):
# Whether the -> and ->> operators expect documents of type JSON (e.g. Postgres)
JSON_ARROWS_REQUIRE_JSON_TYPE = False
+ # Whether the `:` operator is used to extract a value from a JSON document
+ COLON_IS_JSON_EXTRACT = False
+
# Whether or not a VALUES keyword needs to be followed by '(' to form a VALUES clause.
# If this is True and '(' is not found, the keyword will be treated as an identifier
VALUES_FOLLOWED_BY_PAREN = True
@@ -1631,6 +1645,7 @@ class Parser(metaclass=_Parser):
extend_props(self._parse_properties())
expression = self._match(TokenType.ALIAS) and self._parse_heredoc()
+ extend_props(self._parse_properties())
if not expression:
if self._match(TokenType.COMMAND):
@@ -1817,11 +1832,17 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return self._parse_sequence_properties()
- return self.expression(
- exp.Property,
- this=key.to_dot() if isinstance(key, exp.Column) else key,
- value=self._parse_bitwise() or self._parse_var(any_token=True),
- )
+ # Transform the key to an exp.Dot if it's a dotted identifier wrapped in exp.Column, or to an exp.Var otherwise
+ if isinstance(key, exp.Column):
+ key = key.to_dot() if len(key.parts) > 1 else exp.var(key.name)
+
+ value = self._parse_bitwise() or self._parse_var(any_token=True)
+
+ # Transform the value to exp.Var if it was parsed as exp.Column(exp.Identifier())
+ if isinstance(value, exp.Column):
+ value = exp.var(value.name)
+
+ return self.expression(exp.Property, this=key, value=value)
def _parse_stored(self) -> exp.FileFormatProperty:
self._match(TokenType.ALIAS)
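The normalization above means property keys and values no longer leak exp.Column nodes. A rough probe (property name and value chosen arbitrarily; the value parquet is expected to surface as a Var):

    import sqlglot

    ast = sqlglot.parse_one("CREATE TABLE t (x INT) WITH (format = parquet)")
    print(repr(ast))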
@@ -1840,7 +1861,7 @@ class Parser(metaclass=_Parser):
),
)
- def _parse_unquoted_field(self):
+ def _parse_unquoted_field(self) -> t.Optional[exp.Expression]:
field = self._parse_field()
if isinstance(field, exp.Identifier) and not field.quoted:
field = exp.var(field)
@@ -2780,7 +2801,13 @@ class Parser(metaclass=_Parser):
if not alias and not columns:
return None
- return self.expression(exp.TableAlias, this=alias, columns=columns)
+ table_alias = self.expression(exp.TableAlias, this=alias, columns=columns)
+
+ # We bubble up comments from the Identifier to the TableAlias
+ if isinstance(alias, exp.Identifier):
+ table_alias.add_comments(alias.pop_comments())
+
+ return table_alias
def _parse_subquery(
self, this: t.Optional[exp.Expression], parse_alias: bool = True
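Since comments now bubble up from the Identifier to the TableAlias, an alias comment should survive a parse/generate round trip. A small check:

    import sqlglot

    sql = "SELECT * FROM tbl AS t /* primary source */"
    print(sqlglot.parse_one(sql).sql())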
@@ -4047,7 +4074,7 @@ class Parser(metaclass=_Parser):
return this
return self.expression(exp.Escape, this=this, expression=self._parse_string())
- def _parse_interval(self, match_interval: bool = True) -> t.Optional[exp.Interval]:
+ def _parse_interval(self, match_interval: bool = True) -> t.Optional[exp.Add | exp.Interval]:
index = self._index
if not self._match(TokenType.INTERVAL) and match_interval:
@@ -4077,23 +4104,33 @@ class Parser(metaclass=_Parser):
if this and this.is_number:
this = exp.Literal.string(this.name)
elif this and this.is_string:
- parts = this.name.split()
-
- if len(parts) == 2:
+ parts = exp.INTERVAL_STRING_RE.findall(this.name)
+ if len(parts) == 1:
if unit:
- # This is not actually a unit, it's something else (e.g. a "window side")
- unit = None
+ # Unconsume the eagerly-parsed unit, since the real unit was part of the string
self._retreat(self._index - 1)
- this = exp.Literal.string(parts[0])
- unit = self.expression(exp.Var, this=parts[1].upper())
+ this = exp.Literal.string(parts[0][0])
+ unit = self.expression(exp.Var, this=parts[0][1].upper())
if self.INTERVAL_SPANS and self._match_text_seq("TO"):
unit = self.expression(
exp.IntervalSpan, this=unit, expression=self._parse_var(any_token=True, upper=True)
)
- return self.expression(exp.Interval, this=this, unit=unit)
+ interval = self.expression(exp.Interval, this=this, unit=unit)
+
+ index = self._index
+ self._match(TokenType.PLUS)
+
+ # Convert INTERVAL 'val_1' unit_1 [+] ... [+] 'val_n' unit_n into a sum of intervals
+ if self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
+ return self.expression(
+ exp.Add, this=interval, expression=self._parse_interval(match_interval=False)
+ )
+
+ self._retreat(index)
+ return interval
def _parse_bitwise(self) -> t.Optional[exp.Expression]:
this = self._parse_term()
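The sum-of-intervals logic now lives in _parse_interval itself (it previously sat in _parse_type, see the next hunk). Chained value/unit pairs should come back as nested exp.Add nodes:

    import sqlglot

    ast = sqlglot.parse_one("SELECT INTERVAL '1' DAY + '2' HOUR + '3' MINUTE")
    print(ast.sql())  # each value/unit pair should render as its own INTERVAL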
@@ -4155,39 +4192,50 @@ class Parser(metaclass=_Parser):
return self.UNARY_PARSERS[self._prev.token_type](self)
return self._parse_at_time_zone(self._parse_type())
- def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]:
+ def _parse_type(
+ self, parse_interval: bool = True, fallback_to_identifier: bool = False
+ ) -> t.Optional[exp.Expression]:
interval = parse_interval and self._parse_interval()
if interval:
- # Convert INTERVAL 'val_1' unit_1 [+] ... [+] 'val_n' unit_n into a sum of intervals
- while True:
- index = self._index
- self._match(TokenType.PLUS)
-
- if not self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
- self._retreat(index)
- break
-
- interval = self.expression( # type: ignore
- exp.Add, this=interval, expression=self._parse_interval(match_interval=False)
- )
-
return interval
index = self._index
data_type = self._parse_types(check_func=True, allow_identifiers=False)
- this = self._parse_column()
if data_type:
+ index2 = self._index
+ this = self._parse_primary()
+
if isinstance(this, exp.Literal):
parser = self.TYPE_LITERAL_PARSERS.get(data_type.this)
if parser:
return parser(self, this, data_type)
+
return self.expression(exp.Cast, this=this, to=data_type)
- if not data_type.expressions:
- self._retreat(index)
- return self._parse_column()
- return self._parse_column_ops(data_type)
+ # The expressions arg gets set by the parser when we have something like DECIMAL(38, 0)
+ # in the input SQL. In that case, we'll produce these tokens: DECIMAL ( 38 , 0 )
+ #
+ # If the index difference here is greater than 1, that means the parser itself must have
+ # consumed additional tokens such as the DECIMAL scale and precision in the above example.
+ #
+ # If it's not greater than 1, then it must be 1, because we've consumed at least the type
+ # keyword, meaning that the expressions arg of the DataType must have gotten set by a
+ # callable in the TYPE_CONVERTERS mapping. For example, Snowflake converts DECIMAL to
+ # DECIMAL(38, 0) in order to facilitate the data type's transpilation.
+ #
+ # In these cases, we don't really want to return the converted type, but instead retreat
+ # and try to parse a Column or Identifier in the section below.
+ if data_type.expressions and index2 - index > 1:
+ self._retreat(index2)
+ return self._parse_column_ops(data_type)
+
+ self._retreat(index)
+
+ if fallback_to_identifier:
+ return self._parse_id_var()
+
+ this = self._parse_column()
return this and self._parse_column_ops(this)
def _parse_type_size(self) -> t.Optional[exp.DataTypeParam]:
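In the data_type branch above, a literal that follows the type either goes through a TYPE_LITERAL_PARSERS entry or becomes a plain cast. A hedged illustration (which path fires depends on the dialect's parser tables):

    import sqlglot

    # DATE '2020-01-01' is a typed literal; generic dialects typically
    # represent it as a cast to DATE
    print(repr(sqlglot.parse_one("SELECT DATE '2020-01-01'")))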
@@ -4251,7 +4299,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.L_PAREN):
if is_struct:
- expressions = self._parse_csv(self._parse_struct_types)
+ expressions = self._parse_csv(lambda: self._parse_struct_types(type_required=True))
elif nested:
expressions = self._parse_csv(
lambda: self._parse_types(
@@ -4352,8 +4400,26 @@ class Parser(metaclass=_Parser):
elif expressions:
this.set("expressions", expressions)
- while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
- this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True)
+ index = self._index
+
+ # Postgres supports the INT ARRAY[3] syntax as a synonym for INT[3]
+ matched_array = self._match(TokenType.ARRAY)
+
+ while self._curr:
+ matched_l_bracket = self._match(TokenType.L_BRACKET)
+ if not matched_l_bracket and not matched_array:
+ break
+
+ matched_array = False
+ values = self._parse_csv(self._parse_conjunction) or None
+ if values and not schema:
+ self._retreat(index)
+ break
+
+ this = exp.DataType(
+ this=exp.DataType.Type.ARRAY, expressions=[this], values=values, nested=True
+ )
+ self._match(TokenType.R_BRACKET)
if self.TYPE_CONVERTER and isinstance(this.this, exp.DataType.Type):
converter = self.TYPE_CONVERTER.get(this.this)
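A round-trip probe for the Postgres synonym handled above (exact rendering assumed; both inputs should produce equivalent ASTs):

    import sqlglot

    for sql in ("CREATE TABLE t (x INT[3])", "CREATE TABLE t (x INT ARRAY[3])"):
        print(sqlglot.transpile(sql, read="postgres", write="postgres")[0])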
@@ -4364,17 +4430,21 @@ class Parser(metaclass=_Parser):
def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]:
index = self._index
- this = self._parse_type(parse_interval=False) or self._parse_id_var()
+ this = (
+ self._parse_type(parse_interval=False, fallback_to_identifier=True)
+ or self._parse_id_var()
+ )
self._match(TokenType.COLON)
- column_def = self._parse_column_def(this)
- if type_required and (
- (isinstance(this, exp.Column) and this.this is column_def) or this is column_def
+ if (
+ type_required
+ and not isinstance(this, exp.DataType)
+ and not self._match_set(self.TYPE_TOKENS, advance=False)
):
self._retreat(index)
return self._parse_types()
- return column_def
+ return self._parse_column_def(this)
def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match_text_seq("AT", "TIME", "ZONE"):
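The fallback_to_identifier path targets struct fields whose names collide with type keywords, such as a field literally named date. A hedged probe (BigQuery chosen as a struct-heavy dialect):

    import sqlglot

    sql = "SELECT CAST(x AS STRUCT<date DATE, id INT64>)"
    print(sqlglot.parse_one(sql, read="bigquery").sql(dialect="bigquery"))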
@@ -4401,6 +4471,47 @@ class Parser(metaclass=_Parser):
return this
+ def _parse_colon_as_json_extract(
+ self, this: t.Optional[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
+ casts = []
+ json_path = []
+
+ while self._match(TokenType.COLON):
+ start_index = self._index
+ path = self._parse_column_ops(self._parse_field(any_token=True))
+
+ # The cast :: operator has a lower precedence than the extraction operator :, so
+ # we rearrange the AST appropriately to avoid casting the JSON path
+ while isinstance(path, exp.Cast):
+ casts.append(path.to)
+ path = path.this
+
+ if casts:
+ dcolon_offset = next(
+ i
+ for i, t in enumerate(self._tokens[start_index:])
+ if t.token_type == TokenType.DCOLON
+ )
+ end_token = self._tokens[start_index + dcolon_offset - 1]
+ else:
+ end_token = self._prev
+
+ if path:
+ json_path.append(self._find_sql(self._tokens[start_index], end_token))
+
+ if json_path:
+ this = self.expression(
+ exp.JSONExtract,
+ this=this,
+ expression=self.dialect.to_json_path(exp.Literal.string(".".join(json_path))),
+ )
+
+ while casts:
+ this = self.expression(exp.Cast, this=this, to=casts.pop())
+
+ return this
+
def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
this = self._parse_bracket(this)
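End to end, the method above folds a `:` chain into a single JSONExtract and re-nests trailing casts around it. A probe via the public API (assuming the Snowflake dialect sets COLON_IS_JSON_EXTRACT):

    import sqlglot

    ast = sqlglot.parse_one("SELECT payload:a.b::int FROM events", read="snowflake")
    # The Cast should wrap the JSONExtract, not the JSON path itself
    print(repr(ast.expressions[0]))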
@@ -4444,8 +4555,10 @@ class Parser(metaclass=_Parser):
)
else:
this = self.expression(exp.Dot, this=this, expression=field)
+
this = self._parse_bracket(this)
- return this
+
+ return self._parse_colon_as_json_extract(this) if self.COLON_IS_JSON_EXTRACT else this
def _parse_primary(self) -> t.Optional[exp.Expression]:
if self._match_set(self.PRIMARY_PARSERS):
@@ -4680,18 +4793,21 @@ class Parser(metaclass=_Parser):
return self.expression(exp.SessionParameter, this=this, kind=kind)
+ def _parse_lambda_arg(self) -> t.Optional[exp.Expression]:
+ return self._parse_id_var()
+
def _parse_lambda(self, alias: bool = False) -> t.Optional[exp.Expression]:
index = self._index
if self._match(TokenType.L_PAREN):
expressions = t.cast(
- t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_id_var)
+ t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_lambda_arg)
)
if not self._match(TokenType.R_PAREN):
self._retreat(index)
else:
- expressions = [self._parse_id_var()]
+ expressions = [self._parse_lambda_arg()]
if self._match_set(self.LAMBDAS):
return self.LAMBDAS[self._prev.token_type](self, expressions)
@@ -5964,7 +6080,19 @@ class Parser(metaclass=_Parser):
return self.expression(exp.AlterColumn, this=column, default=self._parse_conjunction())
if self._match(TokenType.COMMENT):
return self.expression(exp.AlterColumn, this=column, comment=self._parse_string())
-
+ if self._match_text_seq("DROP", "NOT", "NULL"):
+ return self.expression(
+ exp.AlterColumn,
+ this=column,
+ drop=True,
+ allow_null=True,
+ )
+ if self._match_text_seq("SET", "NOT", "NULL"):
+ return self.expression(
+ exp.AlterColumn,
+ this=column,
+ allow_null=False,
+ )
self._match_text_seq("SET", "DATA")
self._match_text_seq("TYPE")
return self.expression(
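Both new nullability branches in one probe (generic dialect assumed):

    import sqlglot

    for sql in (
        "ALTER TABLE t ALTER COLUMN c DROP NOT NULL",
        "ALTER TABLE t ALTER COLUMN c SET NOT NULL",
    ):
        print(sqlglot.parse_one(sql).sql())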
@@ -6182,8 +6310,10 @@ class Parser(metaclass=_Parser):
return None
right = self._parse_statement() or self._parse_id_var()
- this = self.expression(exp.EQ, this=left, expression=right)
+ if isinstance(right, (exp.Column, exp.Identifier)):
+ right = exp.var(right.name)
+ this = self.expression(exp.EQ, this=left, expression=right)
return self.expression(exp.SetItem, this=this, kind=kind)
def _parse_set_transaction(self, global_: bool = False) -> exp.Expression:
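With the coercion above, a bare identifier on the right-hand side of a SET item round-trips as a Var rather than a Column. A small check:

    import sqlglot

    print(sqlglot.parse_one("SET search_path = public").sql())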
@@ -6433,14 +6563,25 @@ class Parser(metaclass=_Parser):
return True
def _replace_lambda(
- self, node: t.Optional[exp.Expression], lambda_variables: t.Set[str]
+ self, node: t.Optional[exp.Expression], expressions: t.List[exp.Expression]
) -> t.Optional[exp.Expression]:
if not node:
return node
+ lambda_types = {e.name: e.args.get("to") or False for e in expressions}
+
for column in node.find_all(exp.Column):
- if column.parts[0].name in lambda_variables:
+ typ = lambda_types.get(column.parts[0].name)
+ if typ is not None:
dot_or_id = column.to_dot() if column.table else column.this
+
+ if typ:
+ dot_or_id = self.expression(
+ exp.Cast,
+ this=dot_or_id,
+ to=typ,
+ )
+
parent = column.parent
while isinstance(parent, exp.Dot):
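A standalone sketch of the tri-state lookup that replaced the plain name set in _replace_lambda: a typed parameter maps to its DataType (truthy), an untyped one to False, and a name that is not a lambda parameter misses the dict entirely (None). Names here are hypothetical:

    lambda_types = {"x": "INT", "y": False}  # stand-in for {e.name: e.args.get("to") or False}

    for name in ("x", "y", "z"):
        typ = lambda_types.get(name)
        if typ is None:
            print(name, "-> not a lambda arg, column left untouched")
        elif typ:
            print(name, f"-> lambda arg, wrapped as CAST(... AS {typ})")
        else:
            print(name, "-> untyped lambda arg, no cast added")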
@@ -6516,12 +6657,23 @@ class Parser(metaclass=_Parser):
return self.expression(exp.WithOperator, this=this, op=op)
def _parse_wrapped_options(self) -> t.List[t.Optional[exp.Expression]]:
- opts = []
self._match(TokenType.EQ)
self._match(TokenType.L_PAREN)
+
+ opts: t.List[t.Optional[exp.Expression]] = []
while self._curr and not self._match(TokenType.R_PAREN):
- opts.append(self._parse_conjunction())
+ if self._match_text_seq("FORMAT_NAME", "="):
+ # FORMAT_NAME can be set to an identifier in Snowflake and T-SQL,
+ # so we parse it separately, using _parse_field()
+ prop = self.expression(
+ exp.Property, this=exp.var("FORMAT_NAME"), value=self._parse_field()
+ )
+ opts.append(prop)
+ else:
+ opts.append(self._parse_property())
+
self._match(TokenType.COMMA)
+
return opts
def _parse_copy_parameters(self) -> t.List[exp.CopyParameter]:
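A hedged probe of the FORMAT_NAME branch (stage and format names are placeholders; assumes the Snowflake dialect routes COPY INTO through this parser):

    import sqlglot

    sql = "COPY INTO t FROM @my_stage FILE_FORMAT = (FORMAT_NAME = my_csv_format)"
    print(sqlglot.parse_one(sql, read="snowflake").sql(dialect="snowflake"))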
@@ -6529,37 +6681,38 @@ class Parser(metaclass=_Parser):
options = []
while self._curr and not self._match(TokenType.R_PAREN, advance=False):
- option = self._parse_unquoted_field()
- value = None
+ option = self._parse_var(any_token=True)
+ prev = self._prev.text.upper()
- # Some options are defined as functions with the values as params
- if not isinstance(option, exp.Func):
- prev = self._prev.text.upper()
- # Different dialects might separate options and values by white space, "=" and "AS"
- self._match(TokenType.EQ)
- self._match(TokenType.ALIAS)
+ # Different dialects might separate options and values by whitespace, "=" or "AS"
+ self._match(TokenType.EQ)
+ self._match(TokenType.ALIAS)
- if prev == "FILE_FORMAT" and self._match(TokenType.L_PAREN):
- # Snowflake FILE_FORMAT case
- value = self._parse_wrapped_options()
- else:
- value = self._parse_unquoted_field()
+ param = self.expression(exp.CopyParameter, this=option)
- param = self.expression(exp.CopyParameter, this=option, expression=value)
- options.append(param)
+ if prev in self.COPY_INTO_VARLEN_OPTIONS and self._match(
+ TokenType.L_PAREN, advance=False
+ ):
+ # Snowflake FILE_FORMAT case, Databricks COPY & FORMAT options
+ param.set("expressions", self._parse_wrapped_options())
+ elif prev == "FILE_FORMAT":
+ # T-SQL's external file format case
+ param.set("expression", self._parse_field())
+ else:
+ param.set("expression", self._parse_unquoted_field())
- if sep:
- self._match(sep)
+ options.append(param)
+ self._match(sep)
return options
def _parse_credentials(self) -> t.Optional[exp.Credentials]:
expr = self.expression(exp.Credentials)
- if self._match_text_seq("STORAGE_INTEGRATION", advance=False):
- expr.set("storage", self._parse_conjunction())
+ if self._match_text_seq("STORAGE_INTEGRATION", "="):
+ expr.set("storage", self._parse_field())
if self._match_text_seq("CREDENTIALS"):
- # Snowflake supports CREDENTIALS = (...), while Redshift CREDENTIALS <string>
+ # Snowflake case: CREDENTIALS = (...), Redshift case: CREDENTIALS <string>
creds = (
self._parse_wrapped_options() if self._match(TokenType.EQ) else self._parse_field()
)
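Both credential styles named in the comment, as a quick probe (bucket, key and role are placeholders; Redshift support assumed from the comment above):

    import sqlglot

    rs = "COPY t FROM 's3://bucket/key' CREDENTIALS 'aws_iam_role=arn:aws:iam::0123456789:role/r'"
    print(sqlglot.parse_one(rs, read="redshift").sql(dialect="redshift"))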
@@ -6582,7 +6735,7 @@ class Parser(metaclass=_Parser):
self._match(TokenType.INTO)
this = (
- self._parse_conjunction()
+ self._parse_select(nested=True, parse_subquery_alias=False)
if self._match(TokenType.L_PAREN, advance=False)
else self._parse_table(schema=True)
)