From 4554ab4c7d6b2bbbaa6f4d0b810bf477d1a505a6 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 25 Sep 2023 10:20:09 +0200 Subject: Merging upstream version 18.7.0. Signed-off-by: Daniel Baumann --- sqlglot/expressions.py | 68 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 46 insertions(+), 22 deletions(-) (limited to 'sqlglot/expressions.py') diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1c3d42a..8e9575e 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -664,16 +664,6 @@ class Expression(metaclass=_Expression): return load(obj) - -IntoType = t.Union[ - str, - t.Type[Expression], - t.Collection[t.Union[str, t.Type[Expression]]], -] -ExpOrStr = t.Union[str, Expression] - - -class Condition(Expression): def and_( self, *expressions: t.Optional[ExpOrStr], @@ -762,11 +752,19 @@ class Condition(Expression): return klass(this=other, expression=this) return klass(this=this, expression=other) - def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]): + def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]) -> Bracket: return Bracket( this=self.copy(), expressions=[convert(e, copy=True) for e in ensure_list(other)] ) + def __iter__(self) -> t.Iterator: + if "expressions" in self.arg_types: + return iter(self.args.get("expressions") or []) + # We define this because __getitem__ converts Expression into an iterable, which is + # problematic because one can hit infinite loops if they do "for x in some_expr: ..." + # See: https://peps.python.org/pep-0234/ + raise TypeError(f"'{self.__class__.__name__}' object is not iterable") + def isin( self, *expressions: t.Any, @@ -886,6 +884,18 @@ class Condition(Expression): return not_(self.copy()) +IntoType = t.Union[ + str, + t.Type[Expression], + t.Collection[t.Union[str, t.Type[Expression]]], +] +ExpOrStr = t.Union[str, Expression] + + +class Condition(Expression): + """Logical conditions like x AND y, or simply x""" + + class Predicate(Condition): """Relationships like x = y, x > 1, x >= y.""" @@ -1045,6 +1055,10 @@ class Describe(Expression): arg_types = {"this": True, "kind": False, "expressions": False} +class Kill(Expression): + arg_types = {"this": True, "kind": False} + + class Pragma(Expression): pass @@ -1161,7 +1175,7 @@ class Column(Condition): if self.args.get(part) ] - def to_dot(self) -> Dot: + def to_dot(self) -> Dot | Identifier: """Converts the column into a dot expression.""" parts = self.parts parent = self.parent @@ -1171,7 +1185,7 @@ class Column(Condition): parts.append(parent.expression) parent = parent.parent - return Dot.build(deepcopy(parts)) + return Dot.build(deepcopy(parts)) if len(parts) > 1 else parts[0] class ColumnPosition(Expression): @@ -1607,6 +1621,7 @@ class Index(Expression): "primary": False, "amp": False, # teradata "partition_by": False, # teradata + "where": False, # postgres partial indexes } @@ -1917,7 +1932,7 @@ class Sort(Order): class Ordered(Expression): - arg_types = {"this": True, "desc": True, "nulls_first": True} + arg_types = {"this": True, "desc": False, "nulls_first": True} class Property(Expression): @@ -2569,7 +2584,6 @@ class Intersect(Union): class Unnest(UDTF): arg_types = { "expressions": True, - "ordinality": False, "alias": False, "offset": False, } @@ -2862,6 +2876,7 @@ class Select(Subqueryable): prefix="LIMIT", dialect=dialect, copy=copy, + into_arg="expression", **opts, ) @@ -4007,6 +4022,10 @@ class TimeUnit(Expression): super().__init__(**args) + @property + def unit(self) -> t.Optional[Var]: + return self.args.get("unit") + # https://www.oracletutorial.com/oracle-basics/oracle-interval/ # https://trino.io/docs/current/language/types.html#interval-day-to-second @@ -4018,10 +4037,6 @@ class IntervalSpan(Expression): class Interval(TimeUnit): arg_types = {"this": False, "unit": False} - @property - def unit(self) -> t.Optional[Var]: - return self.args.get("unit") - class IgnoreNulls(Expression): pass @@ -4327,6 +4342,10 @@ class DateDiff(Func, TimeUnit): class DateTrunc(Func): arg_types = {"unit": True, "this": True, "zone": False} + @property + def unit(self) -> Expression: + return self.args["unit"] + class DatetimeAdd(Func, TimeUnit): arg_types = {"this": True, "expression": True, "unit": False} @@ -4427,7 +4446,8 @@ class DateToDi(Func): # https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#date class Date(Func): - arg_types = {"this": True, "zone": False} + arg_types = {"this": False, "zone": False, "expressions": False} + is_var_len_args = True class Day(Func): @@ -5131,10 +5151,11 @@ def _apply_builder( prefix=None, into=None, dialect=None, + into_arg="this", **opts, ): if _is_wrong_expression(expression, into): - expression = into(this=expression) + expression = into(**{into_arg: expression}) instance = maybe_copy(instance, copy) expression = maybe_parse( sql_or_expression=expression, @@ -5926,7 +5947,10 @@ def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Ca The new Cast instance. """ expression = maybe_parse(expression, **opts) - return Cast(this=expression, to=DataType.build(to, **opts)) + data_type = DataType.build(to, **opts) + expression = Cast(this=expression, to=data_type) + expression.type = data_type + return expression def table_( -- cgit v1.2.3