summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- sqlglot/expressions.py 225
1 files changed, 138 insertions, 87 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 49d3ff6..9e7379d 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -64,6 +64,13 @@ class Expression(metaclass=_Expression):
and representing expressions as strings.
arg_types: determines what arguments (child nodes) are supported by an expression. It
maps arg keys to booleans that indicate whether the corresponding args are optional.
+ parent: a reference to the parent expression (or None, in case of root expressions).
+ arg_key: the arg key an expression is associated with, i.e. the name its parent expression
+ uses to refer to it.
+ comments: a list of comments that are associated with a given expression. This is used in
+ order to preserve comments when transpiling SQL code.
+ _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
+ optimizer, in order to enable some transformations that require type information.
Example:
>>> class Foo(Expression):
@@ -74,13 +81,6 @@ class Expression(metaclass=_Expression):
Args:
args: a mapping used for retrieving the arguments of an expression, given their arg keys.
- parent: a reference to the parent expression (or None, in case of root expressions).
- arg_key: the arg key an expression is associated with, i.e. the name its parent expression
- uses to refer to it.
- comments: a list of comments that are associated with a given expression. This is used in
- order to preserve comments when transpiling SQL code.
- _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
- optimizer, in order to enable some transformations that require type information.
"""
key = "expression"
@@ -258,6 +258,12 @@ class Expression(metaclass=_Expression):
new.parent = self.parent
return new
+ def add_comments(self, comments: t.Optional[t.List[str]]) -> None:
+ if self.comments is None:
+ self.comments = []
+ if comments:
+ self.comments.extend(comments)
+
def append(self, arg_key, value):
"""
Appends value to arg_key if it's a list or sets it as a new list.
@@ -650,7 +656,7 @@ ExpOrStr = t.Union[str, Expression]
class Condition(Expression):
- def and_(self, *expressions, dialect=None, **opts):
+ def and_(self, *expressions, dialect=None, copy=True, **opts):
"""
AND this condition with one or multiple expressions.
@@ -662,14 +668,15 @@ class Condition(Expression):
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect (str): the dialect used to parse the input expression.
+ copy (bool): whether or not to copy the involved expressions (only applies to Expressions).
opts (kwargs): other options to use to parse the input expressions.
Returns:
And: the new condition.
"""
- return and_(self, *expressions, dialect=dialect, **opts)
+ return and_(self, *expressions, dialect=dialect, copy=copy, **opts)
- def or_(self, *expressions, dialect=None, **opts):
+ def or_(self, *expressions, dialect=None, copy=True, **opts):
"""
OR this condition with one or multiple expressions.
@@ -681,14 +688,15 @@ class Condition(Expression):
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect (str): the dialect used to parse the input expression.
+ copy (bool): whether or not to copy the involved expressions (only applies to Expressions).
opts (kwargs): other options to use to parse the input expressions.
Returns:
Or: the new condition.
"""
- return or_(self, *expressions, dialect=dialect, **opts)
+ return or_(self, *expressions, dialect=dialect, copy=copy, **opts)
- def not_(self):
+ def not_(self, copy=True):
"""
Wrap this condition with NOT.
@@ -696,14 +704,17 @@ class Condition(Expression):
>>> condition("x=1").not_().sql()
'NOT x = 1'
+ Args:
+ copy (bool): whether or not to copy this object.
+
Returns:
Not: the new condition.
"""
- return not_(self)
+ return not_(self, copy=copy)
def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E:
- this = self
- other = convert(other)
+ this = self.copy()
+ other = convert(other, copy=True)
if not isinstance(this, klass) and not isinstance(other, klass):
this = _wrap(this, Binary)
other = _wrap(other, Binary)
@@ -711,20 +722,25 @@ class Condition(Expression):
return klass(this=other, expression=this)
return klass(this=this, expression=other)
- def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]):
- if isinstance(other, slice):
- return Between(
- this=self,
- low=convert(other.start),
- high=convert(other.stop),
- )
- return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)])
+ def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]):
+ return Bracket(
+ this=self.copy(), expressions=[convert(e, copy=True) for e in ensure_list(other)]
+ )
- def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In:
+ def isin(
+ self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts
+ ) -> In:
return In(
- this=self,
- expressions=[convert(e) for e in expressions],
- query=maybe_parse(query, **opts) if query else None,
+ this=_maybe_copy(self, copy),
+ expressions=[convert(e, copy=copy) for e in expressions],
+ query=maybe_parse(query, copy=copy, **opts) if query else None,
+ )
+
+ def between(self, low: t.Any, high: t.Any, copy=True, **opts) -> Between:
+ return Between(
+ this=_maybe_copy(self, copy),
+ low=convert(low, copy=copy, **opts),
+ high=convert(high, copy=copy, **opts),
)
def like(self, other: ExpOrStr) -> Like:
@@ -809,10 +825,10 @@ class Condition(Expression):
return self._binop(Or, other, reverse=True)
def __neg__(self) -> Neg:
- return Neg(this=_wrap(self, Binary))
+ return Neg(this=_wrap(self.copy(), Binary))
def __invert__(self) -> Not:
- return not_(self)
+ return not_(self.copy())
class Predicate(Condition):
@@ -830,11 +846,7 @@ class DerivedTable(Expression):
@property
def selects(self):
- alias = self.args.get("alias")
-
- if alias:
- return alias.columns
- return []
+ return self.this.selects if isinstance(self.this, Subqueryable) else []
@property
def named_selects(self):
@@ -904,7 +916,10 @@ class Unionable(Expression):
class UDTF(DerivedTable, Unionable):
- pass
+ @property
+ def selects(self):
+ alias = self.args.get("alias")
+ return alias.columns if alias else []
class Cache(Expression):
@@ -1073,6 +1088,10 @@ class ColumnDef(Expression):
"position": False,
}
+ @property
+ def constraints(self) -> t.List[ColumnConstraint]:
+ return self.args.get("constraints") or []
+
class AlterColumn(Expression):
arg_types = {
@@ -1100,6 +1119,10 @@ class Comment(Expression):
class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
+ @property
+ def kind(self) -> ColumnConstraintKind:
+ return self.args["kind"]
+
class ColumnConstraintKind(Expression):
pass
@@ -1937,6 +1960,15 @@ class Reference(Expression):
class Tuple(Expression):
arg_types = {"expressions": False}
+ def isin(
+ self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts
+ ) -> In:
+ return In(
+ this=_maybe_copy(self, copy),
+ expressions=[convert(e, copy=copy) for e in expressions],
+ query=maybe_parse(query, copy=copy, **opts) if query else None,
+ )
+
class Subqueryable(Unionable):
def subquery(self, alias=None, copy=True) -> Subquery:
@@ -2236,6 +2268,8 @@ class Select(Subqueryable):
"expressions": False,
"hint": False,
"distinct": False,
+ "struct": False, # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#return_query_results_as_a_value_table
+ "value": False,
"into": False,
"from": False,
**QUERY_MODIFIERS,
@@ -2611,7 +2645,7 @@ class Select(Subqueryable):
join.set("kind", kind.text)
if on:
- on = and_(*ensure_collection(on), dialect=dialect, **opts)
+ on = and_(*ensure_collection(on), dialect=dialect, copy=copy, **opts)
join.set("on", on)
if using:
@@ -2723,7 +2757,7 @@ class Select(Subqueryable):
**opts,
)
- def distinct(self, distinct=True, copy=True) -> Select:
+ def distinct(self, *ons: ExpOrStr, distinct: bool = True, copy: bool = True) -> Select:
"""
Set the DISTINCT expression.
@@ -2732,14 +2766,16 @@ class Select(Subqueryable):
'SELECT DISTINCT x FROM tbl'
Args:
- distinct (bool): whether the Select should be distinct
- copy (bool): if `False`, modify this expression instance in-place.
+ ons: the expressions to distinct on
+ distinct: whether the Select should be distinct
+ copy: if `False`, modify this expression instance in-place.
Returns:
Select: the modified expression.
"""
instance = _maybe_copy(self, copy)
- instance.set("distinct", Distinct() if distinct else None)
+ on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons]) if ons else None
+ instance.set("distinct", Distinct(on=on) if distinct else None)
return instance
def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create:
@@ -2969,6 +3005,10 @@ class DataType(Expression):
USMALLINT = auto()
BIGINT = auto()
UBIGINT = auto()
+ INT128 = auto()
+ UINT128 = auto()
+ INT256 = auto()
+ UINT256 = auto()
FLOAT = auto()
DOUBLE = auto()
DECIMAL = auto()
@@ -3022,6 +3062,8 @@ class DataType(Expression):
Type.TINYINT,
Type.SMALLINT,
Type.BIGINT,
+ Type.INT128,
+ Type.INT256,
}
FLOAT_TYPES = {
@@ -3069,10 +3111,6 @@ class PseudoType(Expression):
pass
-class StructKwarg(Expression):
- arg_types = {"this": True, "expression": True}
-
-
# WHERE x <OP> EXISTS|ALL|ANY|SOME(SELECT ...)
class SubqueryPredicate(Predicate):
pass
@@ -3538,14 +3576,20 @@ class Case(Func):
arg_types = {"this": False, "ifs": True, "default": False}
def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
- this = self.copy() if copy else self
- this.append("ifs", If(this=maybe_parse(condition, **opts), true=maybe_parse(then, **opts)))
- return this
+ instance = _maybe_copy(self, copy)
+ instance.append(
+ "ifs",
+ If(
+ this=maybe_parse(condition, copy=copy, **opts),
+ true=maybe_parse(then, copy=copy, **opts),
+ ),
+ )
+ return instance
def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
- this = self.copy() if copy else self
- this.set("default", maybe_parse(condition, **opts))
- return this
+ instance = _maybe_copy(self, copy)
+ instance.set("default", maybe_parse(condition, copy=copy, **opts))
+ return instance
class Cast(Func):
@@ -3760,6 +3804,14 @@ class Floor(Func):
arg_types = {"this": True, "decimals": False}
+class FromBase64(Func):
+ pass
+
+
+class ToBase64(Func):
+ pass
+
+
class Greatest(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@@ -3930,11 +3982,11 @@ class Pow(Binary, Func):
class PercentileCont(AggFunc):
- pass
+ arg_types = {"this": True, "expression": False}
class PercentileDisc(AggFunc):
- pass
+ arg_types = {"this": True, "expression": False}
class Quantile(AggFunc):
@@ -4405,14 +4457,16 @@ def _apply_conjunction_builder(
if append and existing is not None:
expressions = [existing.this if into else existing] + list(expressions)
- node = and_(*expressions, dialect=dialect, **opts)
+ node = and_(*expressions, dialect=dialect, copy=copy, **opts)
inst.set(arg, into(this=node) if into else node)
return inst
-def _combine(expressions, operator, dialect=None, **opts):
- expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
+def _combine(expressions, operator, dialect=None, copy=True, **opts):
+ expressions = [
+ condition(expression, dialect=dialect, copy=copy, **opts) for expression in expressions
+ ]
this = expressions[0]
if expressions[1:]:
this = _wrap(this, Connector)
@@ -4626,7 +4680,7 @@ def delete(
return delete_expr
-def condition(expression, dialect=None, **opts) -> Condition:
+def condition(expression, dialect=None, copy=True, **opts) -> Condition:
"""
Initialize a logical condition expression.
@@ -4645,6 +4699,7 @@ def condition(expression, dialect=None, **opts) -> Condition:
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression (in the case that the
input expression is a SQL string).
+ copy (bool): Whether or not to copy `expression` (only applies to expressions).
**opts: other options to use to parse the input expressions (again, in the case
that the input expression is a SQL string).
@@ -4655,11 +4710,12 @@ def condition(expression, dialect=None, **opts) -> Condition:
expression,
into=Condition,
dialect=dialect,
+ copy=copy,
**opts,
)
-def and_(*expressions, dialect=None, **opts) -> And:
+def and_(*expressions, dialect=None, copy=True, **opts) -> And:
"""
Combine multiple conditions with an AND logical operator.
@@ -4671,15 +4727,16 @@ def and_(*expressions, dialect=None, **opts) -> And:
*expressions (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression.
+ copy (bool): whether or not to copy `expressions` (only applies to Expressions).
**opts: other options to use to parse the input expressions.
Returns:
And: the new condition
"""
- return _combine(expressions, And, dialect, **opts)
+ return _combine(expressions, And, dialect, copy=copy, **opts)
-def or_(*expressions, dialect=None, **opts) -> Or:
+def or_(*expressions, dialect=None, copy=True, **opts) -> Or:
"""
Combine multiple conditions with an OR logical operator.
@@ -4691,15 +4748,16 @@ def or_(*expressions, dialect=None, **opts) -> Or:
*expressions (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression.
+ copy (bool): whether or not to copy `expressions` (only applies to Expressions).
**opts: other options to use to parse the input expressions.
Returns:
Or: the new condition
"""
- return _combine(expressions, Or, dialect, **opts)
+ return _combine(expressions, Or, dialect, copy=copy, **opts)
-def not_(expression, dialect=None, **opts) -> Not:
+def not_(expression, dialect=None, copy=True, **opts) -> Not:
"""
Wrap a condition with a NOT operator.
@@ -4719,13 +4777,14 @@ def not_(expression, dialect=None, **opts) -> Not:
this = condition(
expression,
dialect=dialect,
+ copy=copy,
**opts,
)
return Not(this=_wrap(this, Connector))
-def paren(expression) -> Paren:
- return Paren(this=expression)
+def paren(expression, copy=True) -> Paren:
+ return Paren(this=_maybe_copy(expression, copy))
SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
@@ -4998,29 +5057,20 @@ def values(
alias: optional alias
columns: Optional list of ordered column names or ordered dictionary of column names to types.
If either are provided then an alias is also required.
- If a dictionary is provided then the first column of the values will be casted to the expected type
- in order to help with type inference.
Returns:
Values: the Values expression object
"""
if columns and not alias:
raise ValueError("Alias is required when providing columns")
- table_alias = (
- TableAlias(this=to_identifier(alias), columns=[to_identifier(x) for x in columns])
- if columns
- else TableAlias(this=to_identifier(alias) if alias else None)
- )
- expressions = [convert(tup) for tup in values]
- if columns and isinstance(columns, dict):
- types = list(columns.values())
- expressions[0].set(
- "expressions",
- [cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)],
- )
+
return Values(
- expressions=expressions,
- alias=table_alias,
+ expressions=[convert(tup) for tup in values],
+ alias=(
+ TableAlias(this=to_identifier(alias), columns=[to_identifier(x) for x in columns])
+ if columns
+ else (TableAlias(this=to_identifier(alias)) if alias else None)
+ ),
)
@@ -5068,19 +5118,20 @@ def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
)
-def convert(value) -> Expression:
+def convert(value: t.Any, copy: bool = False) -> Expression:
"""Convert a python value into an expression object.
Raises an error if a conversion is not possible.
Args:
- value (Any): a python object
+ value: A python object.
+ copy: Whether or not to copy `value` (only applies to Expressions and collections).
Returns:
- Expression: the equivalent expression object
+ Expression: the equivalent expression object.
"""
if isinstance(value, Expression):
- return value
+ return _maybe_copy(value, copy)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, bool):
@@ -5098,13 +5149,13 @@ def convert(value) -> Expression:
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
return DateStrToDate(this=date_literal)
if isinstance(value, tuple):
- return Tuple(expressions=[convert(v) for v in value])
+ return Tuple(expressions=[convert(v, copy=copy) for v in value])
if isinstance(value, list):
- return Array(expressions=[convert(v) for v in value])
+ return Array(expressions=[convert(v, copy=copy) for v in value])
if isinstance(value, dict):
return Map(
- keys=[convert(k) for k in value],
- values=[convert(v) for v in value.values()],
+ keys=[convert(k, copy=copy) for k in value],
+ values=[convert(v, copy=copy) for v in value.values()],
)
raise ValueError(f"Cannot convert {value}")