summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-03-19 10:22:09 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-03-19 10:22:09 +0000
commit66af5c6fc22f6f11e9ea807b274e011a6f64efb7 (patch)
tree08ceed3b311b7b343935c1e55941b9d15e6f56d8 /sqlglot/expressions.py
parentReleasing debian version 11.3.6-1. (diff)
downloadsqlglot-66af5c6fc22f6f11e9ea807b274e011a6f64efb7.tar.xz
sqlglot-66af5c6fc22f6f11e9ea807b274e011a6f64efb7.zip
Merging upstream version 11.4.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py251
1 files changed, 214 insertions, 37 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 0c345b3..b9da4cc 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -301,7 +301,7 @@ class Expression(metaclass=_Expression):
the specified types.
Args:
- expression_types (type): the expression type(s) to match.
+ expression_types: the expression type(s) to match.
Returns:
The node which matches the criteria or None if no such node was found.
@@ -314,7 +314,7 @@ class Expression(metaclass=_Expression):
yields those that match at least one of the specified expression types.
Args:
- expression_types (type): the expression type(s) to match.
+ expression_types: the expression type(s) to match.
Returns:
The generator object.
@@ -328,7 +328,7 @@ class Expression(metaclass=_Expression):
Returns a nearest parent matching expression_types.
Args:
- expression_types (type): the expression type(s) to match.
+ expression_types: the expression type(s) to match.
Returns:
The parent node.
@@ -336,8 +336,7 @@ class Expression(metaclass=_Expression):
ancestor = self.parent
while ancestor and not isinstance(ancestor, expression_types):
ancestor = ancestor.parent
- # ignore type because mypy doesn't know that we're checking type in the loop
- return ancestor # type: ignore[return-value]
+ return t.cast(E, ancestor)
@property
def parent_select(self):
@@ -549,8 +548,12 @@ class Expression(metaclass=_Expression):
def pop(self):
"""
Remove this expression from its AST.
+
+ Returns:
+ The popped expression.
"""
self.replace(None)
+ return self
def assert_is(self, type_):
"""
@@ -626,6 +629,7 @@ IntoType = t.Union[
t.Type[Expression],
t.Collection[t.Union[str, t.Type[Expression]]],
]
+ExpOrStr = t.Union[str, Expression]
class Condition(Expression):
@@ -809,7 +813,7 @@ class Describe(Expression):
class Set(Expression):
- arg_types = {"expressions": True}
+ arg_types = {"expressions": False}
class SetItem(Expression):
@@ -905,6 +909,23 @@ class Column(Condition):
def output_name(self) -> str:
return self.name
+ @property
+ def parts(self) -> t.List[Identifier]:
+ """Return the parts of a column in order catalog, db, table, name."""
+ return [part for part in reversed(list(self.args.values())) if part]
+
+ def to_dot(self) -> Dot:
+ """Converts the column into a dot expression."""
+ parts = self.parts
+ parent = self.parent
+
+ while parent:
+ if isinstance(parent, Dot):
+ parts.append(parent.expression)
+ parent = parent.parent
+
+ return Dot.build(parts)
+
class ColumnDef(Expression):
arg_types = {
@@ -1033,6 +1054,113 @@ class Constraint(Expression):
class Delete(Expression):
arg_types = {"with": False, "this": False, "using": False, "where": False, "returning": False}
+ def delete(
+ self,
+ table: ExpOrStr,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Delete:
+ """
+ Create a DELETE expression or replace the table on an existing DELETE expression.
+
+ Example:
+ >>> delete("tbl").sql()
+ 'DELETE FROM tbl'
+
+ Args:
+ table: the table from which to delete.
+ dialect: the dialect used to parse the input expression.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
+
+ Returns:
+ Delete: the modified expression.
+ """
+ return _apply_builder(
+ expression=table,
+ instance=self,
+ arg="this",
+ dialect=dialect,
+ into=Table,
+ copy=copy,
+ **opts,
+ )
+
+ def where(
+ self,
+ *expressions: ExpOrStr,
+ append: bool = True,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Delete:
+ """
+ Append to or set the WHERE expressions.
+
+ Example:
+ >>> delete("tbl").where("x = 'a' OR x < 'b'").sql()
+ "DELETE FROM tbl WHERE x = 'a' OR x < 'b'"
+
+ Args:
+ *expressions: the SQL code strings to parse.
+ If an `Expression` instance is passed, it will be used as-is.
+ Multiple expressions are combined with an AND operator.
+ append: if `True`, AND the new expressions to any existing expression.
+ Otherwise, this resets the expression.
+ dialect: the dialect used to parse the input expressions.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
+
+ Returns:
+ Delete: the modified expression.
+ """
+ return _apply_conjunction_builder(
+ *expressions,
+ instance=self,
+ arg="where",
+ append=append,
+ into=Where,
+ dialect=dialect,
+ copy=copy,
+ **opts,
+ )
+
+ def returning(
+ self,
+ expression: ExpOrStr,
+ dialect: DialectType = None,
+ copy: bool = True,
+ **opts,
+ ) -> Delete:
+ """
+ Set the RETURNING expression. Not supported by all dialects.
+
+ Example:
+ >>> delete("tbl").returning("*", dialect="postgres").sql()
+ 'DELETE FROM tbl RETURNING *'
+
+ Args:
+ expression: the SQL code strings to parse.
+ If an `Expression` instance is passed, it will be used as-is.
+ dialect: the dialect used to parse the input expressions.
+ copy: if `False`, modify this expression instance in-place.
+ opts: other options to use to parse the input expressions.
+
+ Returns:
+ Delete: the modified expression.
+ """
+ return _apply_builder(
+ expression=expression,
+ instance=self,
+ arg="returning",
+ prefix="RETURNING",
+ dialect=dialect,
+ copy=copy,
+ into=Returning,
+ **opts,
+ )
+
class Drop(Expression):
arg_types = {
@@ -1824,7 +1952,7 @@ class Union(Subqueryable):
def select(
self,
- *expressions: str | Expression,
+ *expressions: ExpOrStr,
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
@@ -2170,7 +2298,7 @@ class Select(Subqueryable):
def select(
self,
- *expressions: str | Expression,
+ *expressions: ExpOrStr,
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
@@ -2875,6 +3003,20 @@ class Dot(Binary):
def name(self) -> str:
return self.expression.name
+ @classmethod
+ def build(self, expressions: t.Sequence[Expression]) -> Dot:
+ """Build a Dot object with a sequence of expressions."""
+ if len(expressions) < 2:
+ raise ValueError(f"Dot requires >= 2 expressions.")
+
+ a, b, *expressions = expressions
+ dot = Dot(this=a, expression=b)
+
+ for expression in expressions:
+ dot = Dot(this=dot, expression=expression)
+
+ return dot
+
class DPipe(Binary):
pass
@@ -3049,7 +3191,7 @@ class TimeUnit(Expression):
def __init__(self, **args):
unit = args.get("unit")
- if isinstance(unit, Column):
+ if isinstance(unit, (Column, Literal)):
args["unit"] = Var(this=unit.name)
elif isinstance(unit, Week):
unit.set("this", Var(this=unit.this.name))
@@ -3261,6 +3403,10 @@ class Count(AggFunc):
arg_types = {"this": False}
+class CountIf(AggFunc):
+ pass
+
+
class CurrentDate(Func):
arg_types = {"this": False}
@@ -3407,6 +3553,10 @@ class Explode(Func):
pass
+class ExponentialTimeDecayedAvg(AggFunc):
+ arg_types = {"this": True, "time": False, "decay": False}
+
+
class Floor(Func):
arg_types = {"this": True, "decimals": False}
@@ -3420,10 +3570,18 @@ class GroupConcat(Func):
arg_types = {"this": True, "separator": False}
+class GroupUniqArray(AggFunc):
+ arg_types = {"this": True, "size": False}
+
+
class Hex(Func):
pass
+class Histogram(AggFunc):
+ arg_types = {"this": True, "bins": False}
+
+
class If(Func):
arg_types = {"this": True, "true": True, "false": False}
@@ -3493,7 +3651,11 @@ class Log10(Func):
class LogicalOr(AggFunc):
- _sql_names = ["LOGICAL_OR", "BOOL_OR"]
+ _sql_names = ["LOGICAL_OR", "BOOL_OR", "BOOLOR_AGG"]
+
+
+class LogicalAnd(AggFunc):
+ _sql_names = ["LOGICAL_AND", "BOOL_AND", "BOOLAND_AGG"]
class Lower(Func):
@@ -3561,6 +3723,7 @@ class Quantile(AggFunc):
# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles
class Quantiles(AggFunc):
arg_types = {"parameters": True, "expressions": True}
+ is_var_len_args = True
class QuantileIf(AggFunc):
@@ -3830,7 +3993,7 @@ ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
# Helpers
def maybe_parse(
- sql_or_expression: str | Expression,
+ sql_or_expression: ExpOrStr,
*,
into: t.Optional[IntoType] = None,
dialect: DialectType = None,
@@ -4091,7 +4254,7 @@ def except_(left, right, distinct=True, dialect=None, **opts):
return Except(this=left, expression=right, distinct=distinct)
-def select(*expressions: str | Expression, dialect: DialectType = None, **opts) -> Select:
+def select(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Select:
"""
Initializes a syntax tree from one or multiple SELECT expressions.
@@ -4135,7 +4298,14 @@ def from_(*expressions, dialect=None, **opts) -> Select:
return Select().from_(*expressions, dialect=dialect, **opts)
-def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
+def update(
+ table: str | Table,
+ properties: dict,
+ where: t.Optional[ExpOrStr] = None,
+ from_: t.Optional[ExpOrStr] = None,
+ dialect: DialectType = None,
+ **opts,
+) -> Update:
"""
Creates an update statement.
@@ -4144,18 +4314,18 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> U
"UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz WHERE id > 1"
Args:
- *properties (Dict[str, Any]): dictionary of properties to set which are
+ *properties: dictionary of properties to set which are
auto converted to sql objects eg None -> NULL
- where (str): sql conditional parsed into a WHERE statement
- from_ (str): sql statement parsed into a FROM statement
- dialect (str): the dialect used to parse the input expressions.
+ where: sql conditional parsed into a WHERE statement
+ from_: sql statement parsed into a FROM statement
+ dialect: the dialect used to parse the input expressions.
**opts: other options to use to parse the input expressions.
Returns:
Update: the syntax tree for the UPDATE statement.
"""
- update = Update(this=maybe_parse(table, into=Table, dialect=dialect))
- update.set(
+ update_expr = Update(this=maybe_parse(table, into=Table, dialect=dialect))
+ update_expr.set(
"expressions",
[
EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v))
@@ -4163,21 +4333,27 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> U
],
)
if from_:
- update.set(
+ update_expr.set(
"from",
maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts),
)
if isinstance(where, Condition):
where = Where(this=where)
if where:
- update.set(
+ update_expr.set(
"where",
maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
)
- return update
+ return update_expr
-def delete(table, where=None, dialect=None, **opts) -> Delete:
+def delete(
+ table: ExpOrStr,
+ where: t.Optional[ExpOrStr] = None,
+ returning: t.Optional[ExpOrStr] = None,
+ dialect: DialectType = None,
+ **opts,
+) -> Delete:
"""
Builds a delete statement.
@@ -4186,19 +4362,20 @@ def delete(table, where=None, dialect=None, **opts) -> Delete:
'DELETE FROM my_table WHERE id > 1'
Args:
- where (str|Condition): sql conditional parsed into a WHERE statement
- dialect (str): the dialect used to parse the input expressions.
+ where: sql conditional parsed into a WHERE statement
+ returning: sql conditional parsed into a RETURNING statement
+ dialect: the dialect used to parse the input expressions.
**opts: other options to use to parse the input expressions.
Returns:
Delete: the syntax tree for the DELETE statement.
"""
- return Delete(
- this=maybe_parse(table, into=Table, dialect=dialect, **opts),
- where=Where(this=where)
- if isinstance(where, Condition)
- else maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
- )
+ delete_expr = Delete().delete(table, dialect=dialect, copy=False, **opts)
+ if where:
+ delete_expr = delete_expr.where(where, dialect=dialect, copy=False, **opts)
+ if returning:
+ delete_expr = delete_expr.returning(returning, dialect=dialect, copy=False, **opts)
+ return delete_expr
def condition(expression, dialect=None, **opts) -> Condition:
@@ -4414,7 +4591,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:
def alias_(
- expression: str | Expression,
+ expression: ExpOrStr,
alias: str | Identifier,
table: bool | t.Sequence[str | Identifier] = False,
quoted: t.Optional[bool] = None,
@@ -4516,7 +4693,7 @@ def column(
)
-def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast:
+def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Cast:
"""Cast an expression to a data type.
Example:
@@ -4595,7 +4772,7 @@ def values(
)
-def var(name: t.Optional[str | Expression]) -> Var:
+def var(name: t.Optional[ExpOrStr]) -> Var:
"""Build a SQL variable.
Example:
@@ -4612,7 +4789,7 @@ def var(name: t.Optional[str | Expression]) -> Var:
The new variable node.
"""
if not name:
- raise ValueError(f"Cannot convert empty name into var.")
+ raise ValueError("Cannot convert empty name into var.")
if isinstance(name, Expression):
name = name.name
@@ -4682,7 +4859,7 @@ def convert(value) -> Expression:
raise ValueError(f"Cannot convert {value}")
-def replace_children(expression, fun):
+def replace_children(expression, fun, *args, **kwargs):
"""
Replace children of an expression with the result of a lambda fun(child) -> exp.
"""
@@ -4694,7 +4871,7 @@ def replace_children(expression, fun):
for cn in child_nodes:
if isinstance(cn, Expression):
- for child_node in ensure_collection(fun(cn)):
+ for child_node in ensure_collection(fun(cn, *args, **kwargs)):
new_child_nodes.append(child_node)
child_node.parent = expression
child_node.arg_key = k