summaryrefslogtreecommitdiffstats
path: root/sqlglot/parser.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r--sqlglot/parser.py96
1 files changed, 94 insertions, 2 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 8269525..b3b899c 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -105,6 +105,7 @@ class Parser(metaclass=_Parser):
TokenType.CURRENT_DATETIME: exp.CurrentDate,
TokenType.CURRENT_TIME: exp.CurrentTime,
TokenType.CURRENT_TIMESTAMP: exp.CurrentTimestamp,
+ TokenType.CURRENT_USER: exp.CurrentUser,
}
NESTED_TYPE_TOKENS = {
@@ -285,6 +286,7 @@ class Parser(metaclass=_Parser):
TokenType.CURRENT_DATETIME,
TokenType.CURRENT_TIMESTAMP,
TokenType.CURRENT_TIME,
+ TokenType.CURRENT_USER,
TokenType.FILTER,
TokenType.FIRST,
TokenType.FORMAT,
@@ -674,9 +676,11 @@ class Parser(metaclass=_Parser):
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
+ "DECODE": lambda self: self._parse_decode(),
"EXTRACT": lambda self: self._parse_extract(),
"JSON_OBJECT": lambda self: self._parse_json_object(),
"LOG": lambda self: self._parse_logarithm(),
+ "MATCH": lambda self: self._parse_match_against(),
"POSITION": lambda self: self._parse_position(),
"STRING_AGG": lambda self: self._parse_string_agg(),
"SUBSTRING": lambda self: self._parse_substring(),
@@ -2634,7 +2638,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
maybe_func = True
- if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
+ if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
this = exp.DataType(
this=exp.DataType.Type.ARRAY,
expressions=[exp.DataType.build(type_token.value, expressions=expressions)],
@@ -2959,6 +2963,11 @@ class Parser(metaclass=_Parser):
else:
this = self._parse_select_or_expression()
+ if isinstance(this, exp.EQ):
+ left = this.this
+ if isinstance(left, exp.Column):
+ left.replace(exp.Var(this=left.text("this")))
+
if self._match(TokenType.IGNORE_NULLS):
this = self.expression(exp.IgnoreNulls, this=this)
else:
@@ -2968,8 +2977,16 @@ class Parser(metaclass=_Parser):
def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
index = self._index
- if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT):
+
+ try:
+ if self._parse_select(nested=True):
+ return this
+ except Exception:
+ pass
+ finally:
self._retreat(index)
+
+ if not self._match(TokenType.L_PAREN):
return this
args = self._parse_csv(
@@ -3344,6 +3361,51 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+ def _parse_decode(self) -> t.Optional[exp.Expression]:
+ """
+ There are generally two variants of the DECODE function:
+
+ - DECODE(bin, charset)
+ - DECODE(expression, search, result [, search, result] ... [, default])
+
+ The second variant will always be parsed into a CASE expression. Note that NULL
+ needs special treatment, since we need to explicitly check for it with `IS NULL`,
+ instead of relying on pattern matching.
+ """
+ args = self._parse_csv(self._parse_conjunction)
+
+ if len(args) < 3:
+ return self.expression(exp.Decode, this=seq_get(args, 0), charset=seq_get(args, 1))
+
+ expression, *expressions = args
+ if not expression:
+ return None
+
+ ifs = []
+ for search, result in zip(expressions[::2], expressions[1::2]):
+ if not search or not result:
+ return None
+
+ if isinstance(search, exp.Literal):
+ ifs.append(
+ exp.If(this=exp.EQ(this=expression.copy(), expression=search), true=result)
+ )
+ elif isinstance(search, exp.Null):
+ ifs.append(
+ exp.If(this=exp.Is(this=expression.copy(), expression=exp.Null()), true=result)
+ )
+ else:
+ cond = exp.or_(
+ exp.EQ(this=expression.copy(), expression=search),
+ exp.and_(
+ exp.Is(this=expression.copy(), expression=exp.Null()),
+ exp.Is(this=search.copy(), expression=exp.Null()),
+ ),
+ )
+ ifs.append(exp.If(this=cond, true=result))
+
+ return exp.Case(ifs=ifs, default=expressions[-1] if len(expressions) % 2 == 1 else None)
+
def _parse_json_key_value(self) -> t.Optional[exp.Expression]:
self._match_text_seq("KEY")
key = self._parse_field()
@@ -3398,6 +3460,28 @@ class Parser(metaclass=_Parser):
exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0)
)
+ def _parse_match_against(self) -> exp.Expression:
+ expressions = self._parse_csv(self._parse_column)
+
+ self._match_text_seq(")", "AGAINST", "(")
+
+ this = self._parse_string()
+
+ if self._match_text_seq("IN", "NATURAL", "LANGUAGE", "MODE"):
+ modifier = "IN NATURAL LANGUAGE MODE"
+ if self._match_text_seq("WITH", "QUERY", "EXPANSION"):
+ modifier = f"{modifier} WITH QUERY EXPANSION"
+ elif self._match_text_seq("IN", "BOOLEAN", "MODE"):
+ modifier = "IN BOOLEAN MODE"
+ elif self._match_text_seq("WITH", "QUERY", "EXPANSION"):
+ modifier = "WITH QUERY EXPANSION"
+ else:
+ modifier = None
+
+ return self.expression(
+ exp.MatchAgainst, this=this, expressions=expressions, modifier=modifier
+ )
+
def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
args = self._parse_csv(self._parse_bitwise)
@@ -3791,6 +3875,14 @@ class Parser(metaclass=_Parser):
if expression:
expression.set("exists", exists_column)
+ # https://docs.databricks.com/delta/update-schema.html#explicitly-update-schema-to-add-columns
+ if self._match_texts(("FIRST", "AFTER")):
+ position = self._prev.text
+ column_position = self.expression(
+ exp.ColumnPosition, this=self._parse_column(), position=position
+ )
+ expression.set("position", column_position)
+
return expression
def _parse_drop_column(self) -> t.Optional[exp.Expression]: