summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py286
1 files changed, 253 insertions, 33 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 9011dce..49d3ff6 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -701,6 +701,119 @@ class Condition(Expression):
"""
return not_(self)
+ def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E:
+ this = self
+ other = convert(other)
+ if not isinstance(this, klass) and not isinstance(other, klass):
+ this = _wrap(this, Binary)
+ other = _wrap(other, Binary)
+ if reverse:
+ return klass(this=other, expression=this)
+ return klass(this=this, expression=other)
+
+ def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]):
+ if isinstance(other, slice):
+ return Between(
+ this=self,
+ low=convert(other.start),
+ high=convert(other.stop),
+ )
+ return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)])
+
+ def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In:
+ return In(
+ this=self,
+ expressions=[convert(e) for e in expressions],
+ query=maybe_parse(query, **opts) if query else None,
+ )
+
+ def like(self, other: ExpOrStr) -> Like:
+ return self._binop(Like, other)
+
+ def ilike(self, other: ExpOrStr) -> ILike:
+ return self._binop(ILike, other)
+
+ def eq(self, other: ExpOrStr) -> EQ:
+ return self._binop(EQ, other)
+
+ def neq(self, other: ExpOrStr) -> NEQ:
+ return self._binop(NEQ, other)
+
+ def rlike(self, other: ExpOrStr) -> RegexpLike:
+ return self._binop(RegexpLike, other)
+
+ def __lt__(self, other: ExpOrStr) -> LT:
+ return self._binop(LT, other)
+
+ def __le__(self, other: ExpOrStr) -> LTE:
+ return self._binop(LTE, other)
+
+ def __gt__(self, other: ExpOrStr) -> GT:
+ return self._binop(GT, other)
+
+ def __ge__(self, other: ExpOrStr) -> GTE:
+ return self._binop(GTE, other)
+
+ def __add__(self, other: ExpOrStr) -> Add:
+ return self._binop(Add, other)
+
+ def __radd__(self, other: ExpOrStr) -> Add:
+ return self._binop(Add, other, reverse=True)
+
+ def __sub__(self, other: ExpOrStr) -> Sub:
+ return self._binop(Sub, other)
+
+ def __rsub__(self, other: ExpOrStr) -> Sub:
+ return self._binop(Sub, other, reverse=True)
+
+ def __mul__(self, other: ExpOrStr) -> Mul:
+ return self._binop(Mul, other)
+
+ def __rmul__(self, other: ExpOrStr) -> Mul:
+ return self._binop(Mul, other, reverse=True)
+
+ def __truediv__(self, other: ExpOrStr) -> Div:
+ return self._binop(Div, other)
+
+ def __rtruediv__(self, other: ExpOrStr) -> Div:
+ return self._binop(Div, other, reverse=True)
+
+ def __floordiv__(self, other: ExpOrStr) -> IntDiv:
+ return self._binop(IntDiv, other)
+
+ def __rfloordiv__(self, other: ExpOrStr) -> IntDiv:
+ return self._binop(IntDiv, other, reverse=True)
+
+ def __mod__(self, other: ExpOrStr) -> Mod:
+ return self._binop(Mod, other)
+
+ def __rmod__(self, other: ExpOrStr) -> Mod:
+ return self._binop(Mod, other, reverse=True)
+
+ def __pow__(self, other: ExpOrStr) -> Pow:
+ return self._binop(Pow, other)
+
+ def __rpow__(self, other: ExpOrStr) -> Pow:
+ return self._binop(Pow, other, reverse=True)
+
+ def __and__(self, other: ExpOrStr) -> And:
+ return self._binop(And, other)
+
+ def __rand__(self, other: ExpOrStr) -> And:
+ return self._binop(And, other, reverse=True)
+
+ def __or__(self, other: ExpOrStr) -> Or:
+ return self._binop(Or, other)
+
+ def __ror__(self, other: ExpOrStr) -> Or:
+ return self._binop(Or, other, reverse=True)
+
+ def __neg__(self) -> Neg:
+ return Neg(this=_wrap(self, Binary))
+
+ def __invert__(self) -> Not:
+ return not_(self)
+
class Predicate(Condition):
"""Relationships like x = y, x > 1, x >= y."""
@@ -818,7 +931,6 @@ class Create(Expression):
"properties": False,
"replace": False,
"unique": False,
- "volatile": False,
"indexes": False,
"no_schema_binding": False,
"begin": False,
@@ -1053,6 +1165,11 @@ class NotNullColumnConstraint(ColumnConstraintKind):
arg_types = {"allow_null": False}
+# https://dev.mysql.com/doc/refman/5.7/en/timestamp-initialization.html
+class OnUpdateColumnConstraint(ColumnConstraintKind):
+ pass
+
+
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
arg_types = {"desc": False}
@@ -1197,6 +1314,7 @@ class Drop(Expression):
"materialized": False,
"cascade": False,
"constraints": False,
+ "purge": False,
}
@@ -1287,6 +1405,7 @@ class Insert(Expression):
"with": False,
"this": True,
"expression": False,
+ "conflict": False,
"returning": False,
"overwrite": False,
"exists": False,
@@ -1295,6 +1414,16 @@ class Insert(Expression):
}
+class OnConflict(Expression):
+ arg_types = {
+ "duplicate": False,
+ "expressions": False,
+ "nothing": False,
+ "key": False,
+ "constraint": False,
+ }
+
+
class Returning(Expression):
arg_types = {"expressions": True}
@@ -1326,7 +1455,12 @@ class Partition(Expression):
class Fetch(Expression):
- arg_types = {"direction": False, "count": False}
+ arg_types = {
+ "direction": False,
+ "count": False,
+ "percent": False,
+ "with_ties": False,
+ }
class Group(Expression):
@@ -1374,6 +1508,7 @@ class Join(Expression):
"kind": False,
"using": False,
"natural": False,
+ "hint": False,
}
@property
@@ -1385,6 +1520,10 @@ class Join(Expression):
return self.text("side").upper()
@property
+ def hint(self):
+ return self.text("hint").upper()
+
+ @property
def alias_or_name(self):
return self.this.alias_or_name
@@ -1475,6 +1614,7 @@ class MatchRecognize(Expression):
"after": False,
"pattern": False,
"define": False,
+ "alias": False,
}
@@ -1582,6 +1722,10 @@ class FreespaceProperty(Property):
arg_types = {"this": True, "percent": False}
+class InputOutputFormat(Expression):
+ arg_types = {"input_format": False, "output_format": False}
+
+
class IsolatedLoadingProperty(Property):
arg_types = {
"no": True,
@@ -1646,6 +1790,10 @@ class ReturnsProperty(Property):
arg_types = {"this": True, "is_table": False, "table": False}
+class RowFormatProperty(Property):
+ arg_types = {"this": True}
+
+
class RowFormatDelimitedProperty(Property):
# https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
arg_types = {
@@ -1683,6 +1831,10 @@ class SqlSecurityProperty(Property):
arg_types = {"definer": True}
+class StabilityProperty(Property):
+ arg_types = {"this": True}
+
+
class TableFormatProperty(Property):
arg_types = {"this": True}
@@ -1695,8 +1847,8 @@ class TransientProperty(Property):
arg_types = {"this": False}
-class VolatilityProperty(Property):
- arg_types = {"this": True}
+class VolatileProperty(Property):
+ arg_types = {"this": False}
class WithDataProperty(Property):
@@ -1726,6 +1878,7 @@ class Properties(Expression):
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"RETURNS": ReturnsProperty,
+ "ROW_FORMAT": RowFormatProperty,
"SORTKEY": SortKeyProperty,
"TABLE_FORMAT": TableFormatProperty,
}
@@ -2721,6 +2874,7 @@ class Pivot(Expression):
"expressions": True,
"field": True,
"unpivot": True,
+ "columns": False,
}
@@ -2731,6 +2885,8 @@ class Window(Expression):
"order": False,
"spec": False,
"alias": False,
+ "over": False,
+ "first": False,
}
@@ -2816,6 +2972,7 @@ class DataType(Expression):
FLOAT = auto()
DOUBLE = auto()
DECIMAL = auto()
+ BIGDECIMAL = auto()
BIT = auto()
BOOLEAN = auto()
JSON = auto()
@@ -2964,7 +3121,7 @@ class DropPartition(Expression):
# Binary expressions like (ADD a b)
-class Binary(Expression):
+class Binary(Condition):
arg_types = {"this": True, "expression": True}
@property
@@ -2980,7 +3137,7 @@ class Add(Binary):
pass
-class Connector(Binary, Condition):
+class Connector(Binary):
pass
@@ -3142,7 +3299,7 @@ class ArrayOverlaps(Binary):
# Unary Expressions
# (NOT a)
-class Unary(Expression):
+class Unary(Condition):
pass
@@ -3150,11 +3307,11 @@ class BitwiseNot(Unary):
pass
-class Not(Unary, Condition):
+class Not(Unary):
pass
-class Paren(Unary, Condition):
+class Paren(Unary):
arg_types = {"this": True, "with": False}
@@ -3162,7 +3319,6 @@ class Neg(Unary):
pass
-# Special Functions
class Alias(Expression):
arg_types = {"this": True, "alias": False}
@@ -3381,6 +3537,16 @@ class AnyValue(AggFunc):
class Case(Func):
arg_types = {"this": False, "ifs": True, "default": False}
+ def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
+ this = self.copy() if copy else self
+ this.append("ifs", If(this=maybe_parse(condition, **opts), true=maybe_parse(then, **opts)))
+ return this
+
+ def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
+ this = self.copy() if copy else self
+ this.set("default", maybe_parse(condition, **opts))
+ return this
+
class Cast(Func):
arg_types = {"this": True, "to": True}
@@ -3719,6 +3885,10 @@ class Map(Func):
arg_types = {"keys": False, "values": False}
+class StarMap(Func):
+ pass
+
+
class VarMap(Func):
arg_types = {"keys": True, "values": True}
is_var_len_args = True
@@ -3734,6 +3904,10 @@ class Max(AggFunc):
is_var_len_args = True
+class MD5(Func):
+ _sql_names = ["MD5"]
+
+
class Min(AggFunc):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@@ -3840,6 +4014,15 @@ class SetAgg(AggFunc):
pass
+class SHA(Func):
+ _sql_names = ["SHA", "SHA1"]
+
+
+class SHA2(Func):
+ _sql_names = ["SHA2"]
+ arg_types = {"this": True, "length": False}
+
+
class SortArray(Func):
arg_types = {"this": True, "asc": False}
@@ -4017,6 +4200,12 @@ class When(Func):
arg_types = {"matched": True, "source": False, "condition": False, "then": True}
+# https://docs.oracle.com/javadb/10.8.3.0/ref/rrefsqljnextvaluefor.html
+# https://learn.microsoft.com/en-us/sql/t-sql/functions/next-value-for-transact-sql?view=sql-server-ver16
+class NextValueFor(Func):
+ arg_types = {"this": True, "order": False}
+
+
def _norm_arg(arg):
return arg.lower() if type(arg) is str else arg
@@ -4025,6 +4214,32 @@ ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
# Helpers
+@t.overload
+def maybe_parse(
+ sql_or_expression: ExpOrStr,
+ *,
+ into: t.Type[E],
+ dialect: DialectType = None,
+ prefix: t.Optional[str] = None,
+ copy: bool = False,
+ **opts,
+) -> E:
+ ...
+
+
+@t.overload
+def maybe_parse(
+ sql_or_expression: str | E,
+ *,
+ into: t.Optional[IntoType] = None,
+ dialect: DialectType = None,
+ prefix: t.Optional[str] = None,
+ copy: bool = False,
+ **opts,
+) -> E:
+ ...
+
+
def maybe_parse(
sql_or_expression: ExpOrStr,
*,
@@ -4200,15 +4415,15 @@ def _combine(expressions, operator, dialect=None, **opts):
expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
this = expressions[0]
if expressions[1:]:
- this = _wrap_operator(this)
+ this = _wrap(this, Connector)
for expression in expressions[1:]:
- this = operator(this=this, expression=_wrap_operator(expression))
+ this = operator(this=this, expression=_wrap(expression, Connector))
return this
-def _wrap_operator(expression):
- if isinstance(expression, (And, Or, Not)):
- expression = Paren(this=expression)
+def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
+ if isinstance(expression, kind):
+ return Paren(this=expression)
return expression
@@ -4506,7 +4721,7 @@ def not_(expression, dialect=None, **opts) -> Not:
dialect=dialect,
**opts,
)
- return Not(this=_wrap_operator(this))
+ return Not(this=_wrap(this, Connector))
def paren(expression) -> Paren:
@@ -4657,6 +4872,8 @@ def alias_(
if table:
table_alias = TableAlias(this=alias)
+
+ exp = exp.copy() if isinstance(expression, Expression) else exp
exp.set("alias", table_alias)
if not isinstance(table, bool):
@@ -4864,16 +5081,22 @@ def convert(value) -> Expression:
"""
if isinstance(value, Expression):
return value
- if value is None:
- return NULL
- if isinstance(value, bool):
- return Boolean(this=value)
if isinstance(value, str):
return Literal.string(value)
- if isinstance(value, float) and math.isnan(value):
+ if isinstance(value, bool):
+ return Boolean(this=value)
+ if value is None or (isinstance(value, float) and math.isnan(value)):
return NULL
if isinstance(value, numbers.Number):
return Literal.number(value)
+ if isinstance(value, datetime.datetime):
+ datetime_literal = Literal.string(
+ (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
+ )
+ return TimeStrToTime(this=datetime_literal)
+ if isinstance(value, datetime.date):
+ date_literal = Literal.string(value.strftime("%Y-%m-%d"))
+ return DateStrToDate(this=date_literal)
if isinstance(value, tuple):
return Tuple(expressions=[convert(v) for v in value])
if isinstance(value, list):
@@ -4883,14 +5106,6 @@ def convert(value) -> Expression:
keys=[convert(k) for k in value],
values=[convert(v) for v in value.values()],
)
- if isinstance(value, datetime.datetime):
- datetime_literal = Literal.string(
- (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
- )
- return TimeStrToTime(this=datetime_literal)
- if isinstance(value, datetime.date):
- date_literal = Literal.string(value.strftime("%Y-%m-%d"))
- return DateStrToDate(this=date_literal)
raise ValueError(f"Cannot convert {value}")
@@ -5030,7 +5245,9 @@ def replace_placeholders(expression, *args, **kwargs):
return expression.transform(_replace_placeholders, iter(args), **kwargs)
-def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True) -> Expression:
+def expand(
+ expression: Expression, sources: t.Dict[str, Subqueryable], copy: bool = True
+) -> Expression:
"""Transforms an expression by expanding all referenced sources into subqueries.
Examples:
@@ -5038,6 +5255,9 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
>>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql()
'SELECT * FROM (SELECT * FROM y) AS z /* source: x */'
+ >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y"), "y": parse_one("select * from z")}).sql()
+ 'SELECT * FROM (SELECT * FROM (SELECT * FROM z) AS y /* source: y */) AS z /* source: x */'
+
Args:
expression: The expression to expand.
sources: A dictionary of name to Subqueryables.
@@ -5054,7 +5274,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
if source:
subquery = source.subquery(node.alias or name)
subquery.comments = [f"source: {name}"]
- return subquery
+ return subquery.transform(_expand, copy=False)
return node
return expression.transform(_expand, copy=copy)
@@ -5089,8 +5309,8 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
from sqlglot.dialects.dialect import Dialect
- converted = [convert(arg) for arg in args]
- kwargs = {key: convert(value) for key, value in kwargs.items()}
+ converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect) for arg in args]
+ kwargs = {key: maybe_parse(value, dialect=dialect) for key, value in kwargs.items()}
parser = Dialect.get_or_raise(dialect)().parser()
from_args_list = parser.FUNCTIONS.get(name.upper())