summaryrefslogtreecommitdiffstats
path: root/sqlglot/parser.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r--sqlglot/parser.py121
1 files changed, 78 insertions, 43 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 5f20afc..c29e520 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -8,6 +8,18 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
logger = logging.getLogger("sqlglot")
+def parse_var_map(args):
+ keys = []
+ values = []
+ for i in range(0, len(args), 2):
+ keys.append(args[i])
+ values.append(args[i + 1])
+ return exp.VarMap(
+ keys=exp.Array(expressions=keys),
+ values=exp.Array(expressions=values),
+ )
+
+
class Parser:
"""
Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer`
@@ -48,6 +60,7 @@ class Parser:
start=exp.Literal.number(1),
length=exp.Literal.number(10),
),
+ "VAR_MAP": parse_var_map,
}
NO_PAREN_FUNCTIONS = {
@@ -117,6 +130,7 @@ class Parser:
TokenType.VAR,
TokenType.ALTER,
TokenType.ALWAYS,
+ TokenType.ANTI,
TokenType.BEGIN,
TokenType.BOTH,
TokenType.BUCKET,
@@ -164,6 +178,7 @@ class Parser:
TokenType.ROWS,
TokenType.SCHEMA_COMMENT,
TokenType.SEED,
+ TokenType.SEMI,
TokenType.SET,
TokenType.SHOW,
TokenType.STABLE,
@@ -273,6 +288,8 @@ class Parser:
TokenType.INNER,
TokenType.OUTER,
TokenType.CROSS,
+ TokenType.SEMI,
+ TokenType.ANTI,
}
COLUMN_OPERATORS = {
@@ -318,6 +335,8 @@ class Parser:
exp.Properties: lambda self: self._parse_properties(),
exp.Where: lambda self: self._parse_where(),
exp.Ordered: lambda self: self._parse_ordered(),
+ exp.Having: lambda self: self._parse_having(),
+ exp.With: lambda self: self._parse_with(),
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
}
@@ -338,7 +357,6 @@ class Parser:
TokenType.NULL: lambda *_: exp.Null(),
TokenType.TRUE: lambda *_: exp.Boolean(this=True),
TokenType.FALSE: lambda *_: exp.Boolean(this=False),
- TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(),
TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
@@ -910,7 +928,20 @@ class Parser:
return self.expression(exp.Tuple, expressions=expressions)
def _parse_select(self, nested=False, table=False):
- if self._match(TokenType.SELECT):
+ cte = self._parse_with()
+ if cte:
+ this = self._parse_statement()
+
+ if not this:
+ self.raise_error("Failed to parse any statement following CTE")
+ return cte
+
+ if "with" in this.arg_types:
+ this.set("with", cte)
+ else:
+ self.raise_error(f"{this.key} does not support CTE")
+ this = cte
+ elif self._match(TokenType.SELECT):
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
distinct = self._match(TokenType.DISTINCT)
@@ -938,39 +969,6 @@ class Parser:
if from_:
this.set("from", from_)
self._parse_query_modifiers(this)
- elif self._match(TokenType.WITH):
- recursive = self._match(TokenType.RECURSIVE)
-
- expressions = []
-
- while True:
- expressions.append(self._parse_cte())
-
- if not self._match(TokenType.COMMA):
- break
-
- cte = self.expression(
- exp.With,
- expressions=expressions,
- recursive=recursive,
- )
- this = self._parse_statement()
-
- if not this:
- self.raise_error("Failed to parse any statement following CTE")
- return cte
-
- if "with" in this.arg_types:
- this.set(
- "with",
- self.expression(
- exp.With,
- expressions=expressions,
- recursive=recursive,
- ),
- )
- else:
- self.raise_error(f"{this.key} does not support CTE")
elif (table or nested) and self._match(TokenType.L_PAREN):
this = self._parse_table() if table else self._parse_select(nested=True)
self._parse_query_modifiers(this)
@@ -986,6 +984,26 @@ class Parser:
return self._parse_set_operations(this) if this else None
+ def _parse_with(self):
+ if not self._match(TokenType.WITH):
+ return None
+
+ recursive = self._match(TokenType.RECURSIVE)
+
+ expressions = []
+
+ while True:
+ expressions.append(self._parse_cte())
+
+ if not self._match(TokenType.COMMA):
+ break
+
+ return self.expression(
+ exp.With,
+ expressions=expressions,
+ recursive=recursive,
+ )
+
def _parse_cte(self):
alias = self._parse_table_alias()
if not alias or not alias.this:
@@ -1485,8 +1503,7 @@ class Parser:
unnest = self._parse_unnest()
if unnest:
this = self.expression(exp.In, this=this, unnest=unnest)
- else:
- self._match_l_paren()
+ elif self._match(TokenType.L_PAREN):
expressions = self._parse_csv(self._parse_select_or_expression)
if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
@@ -1495,6 +1512,9 @@ class Parser:
this = self.expression(exp.In, this=this, expressions=expressions)
self._match_r_paren()
+ else:
+ this = self.expression(exp.In, this=this, field=self._parse_field())
+
return this
def _parse_between(self, this):
@@ -1591,7 +1611,7 @@ class Parser:
elif nested:
expressions = self._parse_csv(self._parse_types)
else:
- expressions = self._parse_csv(self._parse_number)
+ expressions = self._parse_csv(self._parse_type)
if not expressions:
self._retreat(index)
@@ -1706,7 +1726,7 @@ class Parser:
def _parse_field(self, any_token=False):
return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token)
- def _parse_function(self):
+ def _parse_function(self, functions=None):
if not self._curr:
return None
@@ -1742,7 +1762,9 @@ class Parser:
self._match_r_paren()
return this
- function = self.FUNCTIONS.get(upper)
+ if functions is None:
+ functions = self.FUNCTIONS
+ function = functions.get(upper)
args = self._parse_csv(self._parse_lambda)
if function:
@@ -2025,10 +2047,20 @@ class Parser:
return self.expression(exp.Cast, this=this, to=to)
def _parse_position(self):
- substr = self._parse_bitwise()
+ args = self._parse_csv(self._parse_bitwise)
+
if self._match(TokenType.IN):
- string = self._parse_bitwise()
- return self.expression(exp.StrPosition, this=string, substr=substr)
+ args.append(self._parse_bitwise())
+
+ # Note: we're parsing in order needle, haystack, position
+ this = exp.StrPosition.from_arg_list(args)
+ self.validate_expression(this, args)
+
+ return this
+
+ def _parse_join_hint(self, func_name):
+ args = self._parse_csv(self._parse_table)
+ return exp.JoinHint(this=func_name.upper(), expressions=args)
def _parse_substring(self):
# Postgres supports the form: substring(string [from int] [for int])
@@ -2247,6 +2279,9 @@ class Parser:
def _parse_placeholder(self):
if self._match(TokenType.PLACEHOLDER):
return exp.Placeholder()
+ elif self._match(TokenType.COLON):
+ self._advance()
+ return exp.Placeholder(this=self._prev.text)
return None
def _parse_except(self):