From 66af5c6fc22f6f11e9ea807b274e011a6f64efb7 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 19 Mar 2023 11:22:09 +0100 Subject: Merging upstream version 11.4.1. Signed-off-by: Daniel Baumann --- sqlglot/expressions.py | 251 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 214 insertions(+), 37 deletions(-) (limited to 'sqlglot/expressions.py') diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 0c345b3..b9da4cc 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -301,7 +301,7 @@ class Expression(metaclass=_Expression): the specified types. Args: - expression_types (type): the expression type(s) to match. + expression_types: the expression type(s) to match. Returns: The node which matches the criteria or None if no such node was found. @@ -314,7 +314,7 @@ class Expression(metaclass=_Expression): yields those that match at least one of the specified expression types. Args: - expression_types (type): the expression type(s) to match. + expression_types: the expression type(s) to match. Returns: The generator object. @@ -328,7 +328,7 @@ class Expression(metaclass=_Expression): Returns a nearest parent matching expression_types. Args: - expression_types (type): the expression type(s) to match. + expression_types: the expression type(s) to match. Returns: The parent node. @@ -336,8 +336,7 @@ class Expression(metaclass=_Expression): ancestor = self.parent while ancestor and not isinstance(ancestor, expression_types): ancestor = ancestor.parent - # ignore type because mypy doesn't know that we're checking type in the loop - return ancestor # type: ignore[return-value] + return t.cast(E, ancestor) @property def parent_select(self): @@ -549,8 +548,12 @@ class Expression(metaclass=_Expression): def pop(self): """ Remove this expression from its AST. + + Returns: + The popped expression. """ self.replace(None) + return self def assert_is(self, type_): """ @@ -626,6 +629,7 @@ IntoType = t.Union[ t.Type[Expression], t.Collection[t.Union[str, t.Type[Expression]]], ] +ExpOrStr = t.Union[str, Expression] class Condition(Expression): @@ -809,7 +813,7 @@ class Describe(Expression): class Set(Expression): - arg_types = {"expressions": True} + arg_types = {"expressions": False} class SetItem(Expression): @@ -905,6 +909,23 @@ class Column(Condition): def output_name(self) -> str: return self.name + @property + def parts(self) -> t.List[Identifier]: + """Return the parts of a column in order catalog, db, table, name.""" + return [part for part in reversed(list(self.args.values())) if part] + + def to_dot(self) -> Dot: + """Converts the column into a dot expression.""" + parts = self.parts + parent = self.parent + + while parent: + if isinstance(parent, Dot): + parts.append(parent.expression) + parent = parent.parent + + return Dot.build(parts) + class ColumnDef(Expression): arg_types = { @@ -1033,6 +1054,113 @@ class Constraint(Expression): class Delete(Expression): arg_types = {"with": False, "this": False, "using": False, "where": False, "returning": False} + def delete( + self, + table: ExpOrStr, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Delete: + """ + Create a DELETE expression or replace the table on an existing DELETE expression. + + Example: + >>> delete("tbl").sql() + 'DELETE FROM tbl' + + Args: + table: the table from which to delete. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + Delete: the modified expression. + """ + return _apply_builder( + expression=table, + instance=self, + arg="this", + dialect=dialect, + into=Table, + copy=copy, + **opts, + ) + + def where( + self, + *expressions: ExpOrStr, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Delete: + """ + Append to or set the WHERE expressions. + + Example: + >>> delete("tbl").where("x = 'a' OR x < 'b'").sql() + "DELETE FROM tbl WHERE x = 'a' OR x < 'b'" + + Args: + *expressions: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + Multiple expressions are combined with an AND operator. + append: if `True`, AND the new expressions to any existing expression. + Otherwise, this resets the expression. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + Delete: the modified expression. + """ + return _apply_conjunction_builder( + *expressions, + instance=self, + arg="where", + append=append, + into=Where, + dialect=dialect, + copy=copy, + **opts, + ) + + def returning( + self, + expression: ExpOrStr, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Delete: + """ + Set the RETURNING expression. Not supported by all dialects. + + Example: + >>> delete("tbl").returning("*", dialect="postgres").sql() + 'DELETE FROM tbl RETURNING *' + + Args: + expression: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + Delete: the modified expression. + """ + return _apply_builder( + expression=expression, + instance=self, + arg="returning", + prefix="RETURNING", + dialect=dialect, + copy=copy, + into=Returning, + **opts, + ) + class Drop(Expression): arg_types = { @@ -1824,7 +1952,7 @@ class Union(Subqueryable): def select( self, - *expressions: str | Expression, + *expressions: ExpOrStr, append: bool = True, dialect: DialectType = None, copy: bool = True, @@ -2170,7 +2298,7 @@ class Select(Subqueryable): def select( self, - *expressions: str | Expression, + *expressions: ExpOrStr, append: bool = True, dialect: DialectType = None, copy: bool = True, @@ -2875,6 +3003,20 @@ class Dot(Binary): def name(self) -> str: return self.expression.name + @classmethod + def build(self, expressions: t.Sequence[Expression]) -> Dot: + """Build a Dot object with a sequence of expressions.""" + if len(expressions) < 2: + raise ValueError(f"Dot requires >= 2 expressions.") + + a, b, *expressions = expressions + dot = Dot(this=a, expression=b) + + for expression in expressions: + dot = Dot(this=dot, expression=expression) + + return dot + class DPipe(Binary): pass @@ -3049,7 +3191,7 @@ class TimeUnit(Expression): def __init__(self, **args): unit = args.get("unit") - if isinstance(unit, Column): + if isinstance(unit, (Column, Literal)): args["unit"] = Var(this=unit.name) elif isinstance(unit, Week): unit.set("this", Var(this=unit.this.name)) @@ -3261,6 +3403,10 @@ class Count(AggFunc): arg_types = {"this": False} +class CountIf(AggFunc): + pass + + class CurrentDate(Func): arg_types = {"this": False} @@ -3407,6 +3553,10 @@ class Explode(Func): pass +class ExponentialTimeDecayedAvg(AggFunc): + arg_types = {"this": True, "time": False, "decay": False} + + class Floor(Func): arg_types = {"this": True, "decimals": False} @@ -3420,10 +3570,18 @@ class GroupConcat(Func): arg_types = {"this": True, "separator": False} +class GroupUniqArray(AggFunc): + arg_types = {"this": True, "size": False} + + class Hex(Func): pass +class Histogram(AggFunc): + arg_types = {"this": True, "bins": False} + + class If(Func): arg_types = {"this": True, "true": True, "false": False} @@ -3493,7 +3651,11 @@ class Log10(Func): class LogicalOr(AggFunc): - _sql_names = ["LOGICAL_OR", "BOOL_OR"] + _sql_names = ["LOGICAL_OR", "BOOL_OR", "BOOLOR_AGG"] + + +class LogicalAnd(AggFunc): + _sql_names = ["LOGICAL_AND", "BOOL_AND", "BOOLAND_AGG"] class Lower(Func): @@ -3561,6 +3723,7 @@ class Quantile(AggFunc): # https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles class Quantiles(AggFunc): arg_types = {"parameters": True, "expressions": True} + is_var_len_args = True class QuantileIf(AggFunc): @@ -3830,7 +3993,7 @@ ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) # Helpers def maybe_parse( - sql_or_expression: str | Expression, + sql_or_expression: ExpOrStr, *, into: t.Optional[IntoType] = None, dialect: DialectType = None, @@ -4091,7 +4254,7 @@ def except_(left, right, distinct=True, dialect=None, **opts): return Except(this=left, expression=right, distinct=distinct) -def select(*expressions: str | Expression, dialect: DialectType = None, **opts) -> Select: +def select(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Select: """ Initializes a syntax tree from one or multiple SELECT expressions. @@ -4135,7 +4298,14 @@ def from_(*expressions, dialect=None, **opts) -> Select: return Select().from_(*expressions, dialect=dialect, **opts) -def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update: +def update( + table: str | Table, + properties: dict, + where: t.Optional[ExpOrStr] = None, + from_: t.Optional[ExpOrStr] = None, + dialect: DialectType = None, + **opts, +) -> Update: """ Creates an update statement. @@ -4144,18 +4314,18 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> U "UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz WHERE id > 1" Args: - *properties (Dict[str, Any]): dictionary of properties to set which are + *properties: dictionary of properties to set which are auto converted to sql objects eg None -> NULL - where (str): sql conditional parsed into a WHERE statement - from_ (str): sql statement parsed into a FROM statement - dialect (str): the dialect used to parse the input expressions. + where: sql conditional parsed into a WHERE statement + from_: sql statement parsed into a FROM statement + dialect: the dialect used to parse the input expressions. **opts: other options to use to parse the input expressions. Returns: Update: the syntax tree for the UPDATE statement. """ - update = Update(this=maybe_parse(table, into=Table, dialect=dialect)) - update.set( + update_expr = Update(this=maybe_parse(table, into=Table, dialect=dialect)) + update_expr.set( "expressions", [ EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) @@ -4163,21 +4333,27 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> U ], ) if from_: - update.set( + update_expr.set( "from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts), ) if isinstance(where, Condition): where = Where(this=where) if where: - update.set( + update_expr.set( "where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts), ) - return update + return update_expr -def delete(table, where=None, dialect=None, **opts) -> Delete: +def delete( + table: ExpOrStr, + where: t.Optional[ExpOrStr] = None, + returning: t.Optional[ExpOrStr] = None, + dialect: DialectType = None, + **opts, +) -> Delete: """ Builds a delete statement. @@ -4186,19 +4362,20 @@ def delete(table, where=None, dialect=None, **opts) -> Delete: 'DELETE FROM my_table WHERE id > 1' Args: - where (str|Condition): sql conditional parsed into a WHERE statement - dialect (str): the dialect used to parse the input expressions. + where: sql conditional parsed into a WHERE statement + returning: sql conditional parsed into a RETURNING statement + dialect: the dialect used to parse the input expressions. **opts: other options to use to parse the input expressions. Returns: Delete: the syntax tree for the DELETE statement. """ - return Delete( - this=maybe_parse(table, into=Table, dialect=dialect, **opts), - where=Where(this=where) - if isinstance(where, Condition) - else maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts), - ) + delete_expr = Delete().delete(table, dialect=dialect, copy=False, **opts) + if where: + delete_expr = delete_expr.where(where, dialect=dialect, copy=False, **opts) + if returning: + delete_expr = delete_expr.returning(returning, dialect=dialect, copy=False, **opts) + return delete_expr def condition(expression, dialect=None, **opts) -> Condition: @@ -4414,7 +4591,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column: def alias_( - expression: str | Expression, + expression: ExpOrStr, alias: str | Identifier, table: bool | t.Sequence[str | Identifier] = False, quoted: t.Optional[bool] = None, @@ -4516,7 +4693,7 @@ def column( ) -def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast: +def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Cast: """Cast an expression to a data type. Example: @@ -4595,7 +4772,7 @@ def values( ) -def var(name: t.Optional[str | Expression]) -> Var: +def var(name: t.Optional[ExpOrStr]) -> Var: """Build a SQL variable. Example: @@ -4612,7 +4789,7 @@ def var(name: t.Optional[str | Expression]) -> Var: The new variable node. """ if not name: - raise ValueError(f"Cannot convert empty name into var.") + raise ValueError("Cannot convert empty name into var.") if isinstance(name, Expression): name = name.name @@ -4682,7 +4859,7 @@ def convert(value) -> Expression: raise ValueError(f"Cannot convert {value}") -def replace_children(expression, fun): +def replace_children(expression, fun, *args, **kwargs): """ Replace children of an expression with the result of a lambda fun(child) -> exp. """ @@ -4694,7 +4871,7 @@ def replace_children(expression, fun): for cn in child_nodes: if isinstance(cn, Expression): - for child_node in ensure_collection(fun(cn)): + for child_node in ensure_collection(fun(cn, *args, **kwargs)): new_child_nodes.append(child_node) child_node.parent = expression child_node.arg_key = k -- cgit v1.2.3