From 67c28dbe67209effad83d93b850caba5ee1e20e3 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 3 May 2023 11:12:28 +0200 Subject: Merging upstream version 11.7.1. Signed-off-by: Daniel Baumann --- sqlglot/expressions.py | 286 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 253 insertions(+), 33 deletions(-) (limited to 'sqlglot/expressions.py') diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 9011dce..49d3ff6 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -701,6 +701,119 @@ class Condition(Expression): """ return not_(self) + def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E: + this = self + other = convert(other) + if not isinstance(this, klass) and not isinstance(other, klass): + this = _wrap(this, Binary) + other = _wrap(other, Binary) + if reverse: + return klass(this=other, expression=this) + return klass(this=this, expression=other) + + def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]): + if isinstance(other, slice): + return Between( + this=self, + low=convert(other.start), + high=convert(other.stop), + ) + return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)]) + + def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In: + return In( + this=self, + expressions=[convert(e) for e in expressions], + query=maybe_parse(query, **opts) if query else None, + ) + + def like(self, other: ExpOrStr) -> Like: + return self._binop(Like, other) + + def ilike(self, other: ExpOrStr) -> ILike: + return self._binop(ILike, other) + + def eq(self, other: ExpOrStr) -> EQ: + return self._binop(EQ, other) + + def neq(self, other: ExpOrStr) -> NEQ: + return self._binop(NEQ, other) + + def rlike(self, other: ExpOrStr) -> RegexpLike: + return self._binop(RegexpLike, other) + + def __lt__(self, other: ExpOrStr) -> LT: + return self._binop(LT, other) + + def __le__(self, other: ExpOrStr) -> LTE: + return self._binop(LTE, other) + + def __gt__(self, other: ExpOrStr) -> GT: + return self._binop(GT, other) + + def __ge__(self, other: ExpOrStr) -> GTE: + return self._binop(GTE, other) + + def __add__(self, other: ExpOrStr) -> Add: + return self._binop(Add, other) + + def __radd__(self, other: ExpOrStr) -> Add: + return self._binop(Add, other, reverse=True) + + def __sub__(self, other: ExpOrStr) -> Sub: + return self._binop(Sub, other) + + def __rsub__(self, other: ExpOrStr) -> Sub: + return self._binop(Sub, other, reverse=True) + + def __mul__(self, other: ExpOrStr) -> Mul: + return self._binop(Mul, other) + + def __rmul__(self, other: ExpOrStr) -> Mul: + return self._binop(Mul, other, reverse=True) + + def __truediv__(self, other: ExpOrStr) -> Div: + return self._binop(Div, other) + + def __rtruediv__(self, other: ExpOrStr) -> Div: + return self._binop(Div, other, reverse=True) + + def __floordiv__(self, other: ExpOrStr) -> IntDiv: + return self._binop(IntDiv, other) + + def __rfloordiv__(self, other: ExpOrStr) -> IntDiv: + return self._binop(IntDiv, other, reverse=True) + + def __mod__(self, other: ExpOrStr) -> Mod: + return self._binop(Mod, other) + + def __rmod__(self, other: ExpOrStr) -> Mod: + return self._binop(Mod, other, reverse=True) + + def __pow__(self, other: ExpOrStr) -> Pow: + return self._binop(Pow, other) + + def __rpow__(self, other: ExpOrStr) -> Pow: + return self._binop(Pow, other, reverse=True) + + def __and__(self, other: ExpOrStr) -> And: + return self._binop(And, other) + + def __rand__(self, other: ExpOrStr) -> And: + return self._binop(And, other, reverse=True) + + def __or__(self, other: ExpOrStr) -> Or: + return self._binop(Or, other) + + def __ror__(self, other: ExpOrStr) -> Or: + return self._binop(Or, other, reverse=True) + + def __neg__(self) -> Neg: + return Neg(this=_wrap(self, Binary)) + + def __invert__(self) -> Not: + return not_(self) + class Predicate(Condition): """Relationships like x = y, x > 1, x >= y.""" @@ -818,7 +931,6 @@ class Create(Expression): "properties": False, "replace": False, "unique": False, - "volatile": False, "indexes": False, "no_schema_binding": False, "begin": False, @@ -1053,6 +1165,11 @@ class NotNullColumnConstraint(ColumnConstraintKind): arg_types = {"allow_null": False} +# https://dev.mysql.com/doc/refman/5.7/en/timestamp-initialization.html +class OnUpdateColumnConstraint(ColumnConstraintKind): + pass + + class PrimaryKeyColumnConstraint(ColumnConstraintKind): arg_types = {"desc": False} @@ -1197,6 +1314,7 @@ class Drop(Expression): "materialized": False, "cascade": False, "constraints": False, + "purge": False, } @@ -1287,6 +1405,7 @@ class Insert(Expression): "with": False, "this": True, "expression": False, + "conflict": False, "returning": False, "overwrite": False, "exists": False, @@ -1295,6 +1414,16 @@ class Insert(Expression): } +class OnConflict(Expression): + arg_types = { + "duplicate": False, + "expressions": False, + "nothing": False, + "key": False, + "constraint": False, + } + + class Returning(Expression): arg_types = {"expressions": True} @@ -1326,7 +1455,12 @@ class Partition(Expression): class Fetch(Expression): - arg_types = {"direction": False, "count": False} + arg_types = { + "direction": False, + "count": False, + "percent": False, + "with_ties": False, + } class Group(Expression): @@ -1374,6 +1508,7 @@ class Join(Expression): "kind": False, "using": False, "natural": False, + "hint": False, } @property @@ -1384,6 +1519,10 @@ class Join(Expression): def side(self): return self.text("side").upper() + @property + def hint(self): + return self.text("hint").upper() + @property def alias_or_name(self): return self.this.alias_or_name @@ -1475,6 +1614,7 @@ class MatchRecognize(Expression): "after": False, "pattern": False, "define": False, + "alias": False, } @@ -1582,6 +1722,10 @@ class FreespaceProperty(Property): arg_types = {"this": True, "percent": False} +class InputOutputFormat(Expression): + arg_types = {"input_format": False, "output_format": False} + + class IsolatedLoadingProperty(Property): arg_types = { "no": True, @@ -1646,6 +1790,10 @@ class ReturnsProperty(Property): arg_types = {"this": True, "is_table": False, "table": False} +class RowFormatProperty(Property): + arg_types = {"this": True} + + class RowFormatDelimitedProperty(Property): # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml arg_types = { @@ -1683,6 +1831,10 @@ class SqlSecurityProperty(Property): arg_types = {"definer": True} +class StabilityProperty(Property): + arg_types = {"this": True} + + class TableFormatProperty(Property): arg_types = {"this": True} @@ -1695,8 +1847,8 @@ class TransientProperty(Property): arg_types = {"this": False} -class VolatilityProperty(Property): - arg_types = {"this": True} +class VolatileProperty(Property): + arg_types = {"this": False} class WithDataProperty(Property): @@ -1726,6 +1878,7 @@ class Properties(Expression): "LOCATION": LocationProperty, "PARTITIONED_BY": PartitionedByProperty, "RETURNS": ReturnsProperty, + "ROW_FORMAT": RowFormatProperty, "SORTKEY": SortKeyProperty, "TABLE_FORMAT": TableFormatProperty, } @@ -2721,6 +2874,7 @@ class Pivot(Expression): "expressions": True, "field": True, "unpivot": True, + "columns": False, } @@ -2731,6 +2885,8 @@ class Window(Expression): "order": False, "spec": False, "alias": False, + "over": False, + "first": False, } @@ -2816,6 +2972,7 @@ class DataType(Expression): FLOAT = auto() DOUBLE = auto() DECIMAL = auto() + BIGDECIMAL = auto() BIT = auto() BOOLEAN = auto() JSON = auto() @@ -2964,7 +3121,7 @@ class DropPartition(Expression): # Binary expressions like (ADD a b) -class Binary(Expression): +class Binary(Condition): arg_types = {"this": True, "expression": True} @property @@ -2980,7 +3137,7 @@ class Add(Binary): pass -class Connector(Binary, Condition): +class Connector(Binary): pass @@ -3142,7 +3299,7 @@ class ArrayOverlaps(Binary): # Unary Expressions # (NOT a) -class Unary(Expression): +class Unary(Condition): pass @@ -3150,11 +3307,11 @@ class BitwiseNot(Unary): pass -class Not(Unary, Condition): +class Not(Unary): pass -class Paren(Unary, Condition): +class Paren(Unary): arg_types = {"this": True, "with": False} @@ -3162,7 +3319,6 @@ class Neg(Unary): pass -# Special Functions class Alias(Expression): arg_types = {"this": True, "alias": False} @@ -3381,6 +3537,16 @@ class AnyValue(AggFunc): class Case(Func): arg_types = {"this": False, "ifs": True, "default": False} + def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case: + this = self.copy() if copy else self + this.append("ifs", If(this=maybe_parse(condition, **opts), true=maybe_parse(then, **opts))) + return this + + def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case: + this = self.copy() if copy else self + this.set("default", maybe_parse(condition, **opts)) + return this + class Cast(Func): arg_types = {"this": True, "to": True} @@ -3719,6 +3885,10 @@ class Map(Func): arg_types = {"keys": False, "values": False} +class StarMap(Func): + pass + + class VarMap(Func): arg_types = {"keys": True, "values": True} is_var_len_args = True @@ -3734,6 +3904,10 @@ class Max(AggFunc): is_var_len_args = True +class MD5(Func): + _sql_names = ["MD5"] + + class Min(AggFunc): arg_types = {"this": True, "expressions": False} is_var_len_args = True @@ -3840,6 +4014,15 @@ class SetAgg(AggFunc): pass +class SHA(Func): + _sql_names = ["SHA", "SHA1"] + + +class SHA2(Func): + _sql_names = ["SHA2"] + arg_types = {"this": True, "length": False} + + class SortArray(Func): arg_types = {"this": True, "asc": False} @@ -4017,6 +4200,12 @@ class When(Func): arg_types = {"matched": True, "source": False, "condition": False, "then": True} +# https://docs.oracle.com/javadb/10.8.3.0/ref/rrefsqljnextvaluefor.html +# https://learn.microsoft.com/en-us/sql/t-sql/functions/next-value-for-transact-sql?view=sql-server-ver16 +class NextValueFor(Func): + arg_types = {"this": True, "order": False} + + def _norm_arg(arg): return arg.lower() if type(arg) is str else arg @@ -4025,6 +4214,32 @@ ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) # Helpers +@t.overload +def maybe_parse( + sql_or_expression: ExpOrStr, + *, + into: t.Type[E], + dialect: DialectType = None, + prefix: t.Optional[str] = None, + copy: bool = False, + **opts, +) -> E: + ... + + +@t.overload +def maybe_parse( + sql_or_expression: str | E, + *, + into: t.Optional[IntoType] = None, + dialect: DialectType = None, + prefix: t.Optional[str] = None, + copy: bool = False, + **opts, +) -> E: + ... + + def maybe_parse( sql_or_expression: ExpOrStr, *, @@ -4200,15 +4415,15 @@ def _combine(expressions, operator, dialect=None, **opts): expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions] this = expressions[0] if expressions[1:]: - this = _wrap_operator(this) + this = _wrap(this, Connector) for expression in expressions[1:]: - this = operator(this=this, expression=_wrap_operator(expression)) + this = operator(this=this, expression=_wrap(expression, Connector)) return this -def _wrap_operator(expression): - if isinstance(expression, (And, Or, Not)): - expression = Paren(this=expression) +def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren: + if isinstance(expression, kind): + return Paren(this=expression) return expression @@ -4506,7 +4721,7 @@ def not_(expression, dialect=None, **opts) -> Not: dialect=dialect, **opts, ) - return Not(this=_wrap_operator(this)) + return Not(this=_wrap(this, Connector)) def paren(expression) -> Paren: @@ -4657,6 +4872,8 @@ def alias_( if table: table_alias = TableAlias(this=alias) + + exp = exp.copy() if isinstance(expression, Expression) else exp exp.set("alias", table_alias) if not isinstance(table, bool): @@ -4864,16 +5081,22 @@ def convert(value) -> Expression: """ if isinstance(value, Expression): return value - if value is None: - return NULL - if isinstance(value, bool): - return Boolean(this=value) if isinstance(value, str): return Literal.string(value) - if isinstance(value, float) and math.isnan(value): + if isinstance(value, bool): + return Boolean(this=value) + if value is None or (isinstance(value, float) and math.isnan(value)): return NULL if isinstance(value, numbers.Number): return Literal.number(value) + if isinstance(value, datetime.datetime): + datetime_literal = Literal.string( + (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat() + ) + return TimeStrToTime(this=datetime_literal) + if isinstance(value, datetime.date): + date_literal = Literal.string(value.strftime("%Y-%m-%d")) + return DateStrToDate(this=date_literal) if isinstance(value, tuple): return Tuple(expressions=[convert(v) for v in value]) if isinstance(value, list): @@ -4883,14 +5106,6 @@ def convert(value) -> Expression: keys=[convert(k) for k in value], values=[convert(v) for v in value.values()], ) - if isinstance(value, datetime.datetime): - datetime_literal = Literal.string( - (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat() - ) - return TimeStrToTime(this=datetime_literal) - if isinstance(value, datetime.date): - date_literal = Literal.string(value.strftime("%Y-%m-%d")) - return DateStrToDate(this=date_literal) raise ValueError(f"Cannot convert {value}") @@ -5030,7 +5245,9 @@ def replace_placeholders(expression, *args, **kwargs): return expression.transform(_replace_placeholders, iter(args), **kwargs) -def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True) -> Expression: +def expand( + expression: Expression, sources: t.Dict[str, Subqueryable], copy: bool = True +) -> Expression: """Transforms an expression by expanding all referenced sources into subqueries. Examples: @@ -5038,6 +5255,9 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql() 'SELECT * FROM (SELECT * FROM y) AS z /* source: x */' + >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y"), "y": parse_one("select * from z")}).sql() + 'SELECT * FROM (SELECT * FROM (SELECT * FROM z) AS y /* source: y */) AS z /* source: x */' + Args: expression: The expression to expand. sources: A dictionary of name to Subqueryables. @@ -5054,7 +5274,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True if source: subquery = source.subquery(node.alias or name) subquery.comments = [f"source: {name}"] - return subquery + return subquery.transform(_expand, copy=False) return node return expression.transform(_expand, copy=copy) @@ -5089,8 +5309,8 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: from sqlglot.dialects.dialect import Dialect - converted = [convert(arg) for arg in args] - kwargs = {key: convert(value) for key, value in kwargs.items()} + converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect) for arg in args] + kwargs = {key: maybe_parse(value, dialect=dialect) for key, value in kwargs.items()} parser = Dialect.get_or_raise(dialect)().parser() from_args_list = parser.FUNCTIONS.get(name.upper()) -- cgit v1.2.3