From ebec59cc5cb6c6856705bf82ced7fe8d9f75b0d0 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 7 Mar 2023 19:09:31 +0100 Subject: Merging upstream version 11.3.0. Signed-off-by: Daniel Baumann --- sqlglot/expressions.py | 38 +++++++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) (limited to 'sqlglot/expressions.py') diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 59881d6..00a3b45 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -35,6 +35,8 @@ from sqlglot.tokens import Token if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType +E = t.TypeVar("E", bound="Expression") + class _Expression(type): def __new__(cls, clsname, bases, attrs): @@ -293,7 +295,7 @@ class Expression(metaclass=_Expression): return self.parent.depth + 1 return 0 - def find(self, *expression_types, bfs=True): + def find(self, *expression_types: t.Type[E], bfs=True) -> E | None: """ Returns the first node in this tree which matches at least one of the specified types. @@ -306,7 +308,7 @@ class Expression(metaclass=_Expression): """ return next(self.find_all(*expression_types, bfs=bfs), None) - def find_all(self, *expression_types, bfs=True): + def find_all(self, *expression_types: t.Type[E], bfs=True) -> t.Iterator[E]: """ Returns a generator object which visits all nodes in this tree and only yields those that match at least one of the specified expression types. @@ -321,7 +323,7 @@ class Expression(metaclass=_Expression): if isinstance(expression, expression_types): yield expression - def find_ancestor(self, *expression_types): + def find_ancestor(self, *expression_types: t.Type[E]) -> E | None: """ Returns a nearest parent matching expression_types. @@ -334,7 +336,8 @@ class Expression(metaclass=_Expression): ancestor = self.parent while ancestor and not isinstance(ancestor, expression_types): ancestor = ancestor.parent - return ancestor + # ignore type because mypy doesn't know that we're checking type in the loop + return ancestor # type: ignore[return-value] @property def parent_select(self): @@ -794,6 +797,7 @@ class Create(Expression): "properties": False, "replace": False, "unique": False, + "volatile": False, "indexes": False, "no_schema_binding": False, "begin": False, @@ -883,7 +887,7 @@ class ByteString(Condition): class Column(Condition): - arg_types = {"this": True, "table": False, "db": False, "catalog": False} + arg_types = {"this": True, "table": False, "db": False, "catalog": False, "join_mark": False} @property def table(self) -> str: @@ -926,6 +930,14 @@ class RenameTable(Expression): pass +class SetTag(Expression): + arg_types = {"expressions": True, "unset": False} + + +class Comment(Expression): + arg_types = {"this": True, "kind": True, "expression": True, "exists": False} + + class ColumnConstraint(Expression): arg_types = {"this": False, "kind": True} @@ -2829,6 +2841,14 @@ class Div(Binary): pass +class FloatDiv(Binary): + pass + + +class Overlaps(Binary): + pass + + class Dot(Binary): @property def name(self) -> str: @@ -3125,6 +3145,10 @@ class ArrayFilter(Func): _sql_names = ["FILTER", "ARRAY_FILTER"] +class ArrayJoin(Func): + arg_types = {"this": True, "expression": True, "null": False} + + class ArraySize(Func): arg_types = {"this": True, "expression": False} @@ -3510,6 +3534,10 @@ class ApproxQuantile(Quantile): arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False} +class RangeN(Func): + arg_types = {"this": True, "expressions": True, "each": False} + + class ReadCSV(Func): _sql_names = ["READ_CSV"] is_var_len_args = True -- cgit v1.2.3