diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/expressions.py | 225 |
1 files changed, 138 insertions, 87 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 49d3ff6..9e7379d 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -64,6 +64,13 @@ class Expression(metaclass=_Expression): and representing expressions as strings. arg_types: determines what arguments (child nodes) are supported by an expression. It maps arg keys to booleans that indicate whether the corresponding args are optional. + parent: a reference to the parent expression (or None, in case of root expressions). + arg_key: the arg key an expression is associated with, i.e. the name its parent expression + uses to refer to it. + comments: a list of comments that are associated with a given expression. This is used in + order to preserve comments when transpiling SQL code. + _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the + optimizer, in order to enable some transformations that require type information. Example: >>> class Foo(Expression): @@ -74,13 +81,6 @@ class Expression(metaclass=_Expression): Args: args: a mapping used for retrieving the arguments of an expression, given their arg keys. - parent: a reference to the parent expression (or None, in case of root expressions). - arg_key: the arg key an expression is associated with, i.e. the name its parent expression - uses to refer to it. - comments: a list of comments that are associated with a given expression. This is used in - order to preserve comments when transpiling SQL code. - _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the - optimizer, in order to enable some transformations that require type information. """ key = "expression" @@ -258,6 +258,12 @@ class Expression(metaclass=_Expression): new.parent = self.parent return new + def add_comments(self, comments: t.Optional[t.List[str]]) -> None: + if self.comments is None: + self.comments = [] + if comments: + self.comments.extend(comments) + def append(self, arg_key, value): """ Appends value to arg_key if it's a list or sets it as a new list. @@ -650,7 +656,7 @@ ExpOrStr = t.Union[str, Expression] class Condition(Expression): - def and_(self, *expressions, dialect=None, **opts): + def and_(self, *expressions, dialect=None, copy=True, **opts): """ AND this condition with one or multiple expressions. @@ -662,14 +668,15 @@ class Condition(Expression): *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. dialect (str): the dialect used to parse the input expression. + copy (bool): whether or not to copy the involved expressions (only applies to Expressions). opts (kwargs): other options to use to parse the input expressions. Returns: And: the new condition. """ - return and_(self, *expressions, dialect=dialect, **opts) + return and_(self, *expressions, dialect=dialect, copy=copy, **opts) - def or_(self, *expressions, dialect=None, **opts): + def or_(self, *expressions, dialect=None, copy=True, **opts): """ OR this condition with one or multiple expressions. @@ -681,14 +688,15 @@ class Condition(Expression): *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. dialect (str): the dialect used to parse the input expression. + copy (bool): whether or not to copy the involved expressions (only applies to Expressions). opts (kwargs): other options to use to parse the input expressions. Returns: Or: the new condition. """ - return or_(self, *expressions, dialect=dialect, **opts) + return or_(self, *expressions, dialect=dialect, copy=copy, **opts) - def not_(self): + def not_(self, copy=True): """ Wrap this condition with NOT. @@ -696,14 +704,17 @@ class Condition(Expression): >>> condition("x=1").not_().sql() 'NOT x = 1' + Args: + copy (bool): whether or not to copy this object. + Returns: Not: the new condition. """ - return not_(self) + return not_(self, copy=copy) def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E: - this = self - other = convert(other) + this = self.copy() + other = convert(other, copy=True) if not isinstance(this, klass) and not isinstance(other, klass): this = _wrap(this, Binary) other = _wrap(other, Binary) @@ -711,20 +722,25 @@ class Condition(Expression): return klass(this=other, expression=this) return klass(this=this, expression=other) - def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]): - if isinstance(other, slice): - return Between( - this=self, - low=convert(other.start), - high=convert(other.stop), - ) - return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)]) + def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]): + return Bracket( + this=self.copy(), expressions=[convert(e, copy=True) for e in ensure_list(other)] + ) - def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In: + def isin( + self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts + ) -> In: return In( - this=self, - expressions=[convert(e) for e in expressions], - query=maybe_parse(query, **opts) if query else None, + this=_maybe_copy(self, copy), + expressions=[convert(e, copy=copy) for e in expressions], + query=maybe_parse(query, copy=copy, **opts) if query else None, + ) + + def between(self, low: t.Any, high: t.Any, copy=True, **opts) -> Between: + return Between( + this=_maybe_copy(self, copy), + low=convert(low, copy=copy, **opts), + high=convert(high, copy=copy, **opts), ) def like(self, other: ExpOrStr) -> Like: @@ -809,10 +825,10 @@ class Condition(Expression): return self._binop(Or, other, reverse=True) def __neg__(self) -> Neg: - return Neg(this=_wrap(self, Binary)) + return Neg(this=_wrap(self.copy(), Binary)) def __invert__(self) -> Not: - return not_(self) + return not_(self.copy()) class Predicate(Condition): @@ -830,11 +846,7 @@ class DerivedTable(Expression): @property def selects(self): - alias = self.args.get("alias") - - if alias: - return alias.columns - return [] + return self.this.selects if isinstance(self.this, Subqueryable) else [] @property def named_selects(self): @@ -904,7 +916,10 @@ class Unionable(Expression): class UDTF(DerivedTable, Unionable): - pass + @property + def selects(self): + alias = self.args.get("alias") + return alias.columns if alias else [] class Cache(Expression): @@ -1073,6 +1088,10 @@ class ColumnDef(Expression): "position": False, } + @property + def constraints(self) -> t.List[ColumnConstraint]: + return self.args.get("constraints") or [] + class AlterColumn(Expression): arg_types = { @@ -1100,6 +1119,10 @@ class Comment(Expression): class ColumnConstraint(Expression): arg_types = {"this": False, "kind": True} + @property + def kind(self) -> ColumnConstraintKind: + return self.args["kind"] + class ColumnConstraintKind(Expression): pass @@ -1937,6 +1960,15 @@ class Reference(Expression): class Tuple(Expression): arg_types = {"expressions": False} + def isin( + self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts + ) -> In: + return In( + this=_maybe_copy(self, copy), + expressions=[convert(e, copy=copy) for e in expressions], + query=maybe_parse(query, copy=copy, **opts) if query else None, + ) + class Subqueryable(Unionable): def subquery(self, alias=None, copy=True) -> Subquery: @@ -2236,6 +2268,8 @@ class Select(Subqueryable): "expressions": False, "hint": False, "distinct": False, + "struct": False, # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#return_query_results_as_a_value_table + "value": False, "into": False, "from": False, **QUERY_MODIFIERS, @@ -2611,7 +2645,7 @@ class Select(Subqueryable): join.set("kind", kind.text) if on: - on = and_(*ensure_collection(on), dialect=dialect, **opts) + on = and_(*ensure_collection(on), dialect=dialect, copy=copy, **opts) join.set("on", on) if using: @@ -2723,7 +2757,7 @@ class Select(Subqueryable): **opts, ) - def distinct(self, distinct=True, copy=True) -> Select: + def distinct(self, *ons: ExpOrStr, distinct: bool = True, copy: bool = True) -> Select: """ Set the OFFSET expression. @@ -2732,14 +2766,16 @@ class Select(Subqueryable): 'SELECT DISTINCT x FROM tbl' Args: - distinct (bool): whether the Select should be distinct - copy (bool): if `False`, modify this expression instance in-place. + ons: the expressions to distinct on + distinct: whether the Select should be distinct + copy: if `False`, modify this expression instance in-place. Returns: Select: the modified expression. """ instance = _maybe_copy(self, copy) - instance.set("distinct", Distinct() if distinct else None) + on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons]) if ons else None + instance.set("distinct", Distinct(on=on) if distinct else None) return instance def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create: @@ -2969,6 +3005,10 @@ class DataType(Expression): USMALLINT = auto() BIGINT = auto() UBIGINT = auto() + INT128 = auto() + UINT128 = auto() + INT256 = auto() + UINT256 = auto() FLOAT = auto() DOUBLE = auto() DECIMAL = auto() @@ -3022,6 +3062,8 @@ class DataType(Expression): Type.TINYINT, Type.SMALLINT, Type.BIGINT, + Type.INT128, + Type.INT256, } FLOAT_TYPES = { @@ -3069,10 +3111,6 @@ class PseudoType(Expression): pass -class StructKwarg(Expression): - arg_types = {"this": True, "expression": True} - - # WHERE x <OP> EXISTS|ALL|ANY|SOME(SELECT ...) class SubqueryPredicate(Predicate): pass @@ -3538,14 +3576,20 @@ class Case(Func): arg_types = {"this": False, "ifs": True, "default": False} def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case: - this = self.copy() if copy else self - this.append("ifs", If(this=maybe_parse(condition, **opts), true=maybe_parse(then, **opts))) - return this + instance = _maybe_copy(self, copy) + instance.append( + "ifs", + If( + this=maybe_parse(condition, copy=copy, **opts), + true=maybe_parse(then, copy=copy, **opts), + ), + ) + return instance def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case: - this = self.copy() if copy else self - this.set("default", maybe_parse(condition, **opts)) - return this + instance = _maybe_copy(self, copy) + instance.set("default", maybe_parse(condition, copy=copy, **opts)) + return instance class Cast(Func): @@ -3760,6 +3804,14 @@ class Floor(Func): arg_types = {"this": True, "decimals": False} +class FromBase64(Func): + pass + + +class ToBase64(Func): + pass + + class Greatest(Func): arg_types = {"this": True, "expressions": False} is_var_len_args = True @@ -3930,11 +3982,11 @@ class Pow(Binary, Func): class PercentileCont(AggFunc): - pass + arg_types = {"this": True, "expression": False} class PercentileDisc(AggFunc): - pass + arg_types = {"this": True, "expression": False} class Quantile(AggFunc): @@ -4405,14 +4457,16 @@ def _apply_conjunction_builder( if append and existing is not None: expressions = [existing.this if into else existing] + list(expressions) - node = and_(*expressions, dialect=dialect, **opts) + node = and_(*expressions, dialect=dialect, copy=copy, **opts) inst.set(arg, into(this=node) if into else node) return inst -def _combine(expressions, operator, dialect=None, **opts): - expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions] +def _combine(expressions, operator, dialect=None, copy=True, **opts): + expressions = [ + condition(expression, dialect=dialect, copy=copy, **opts) for expression in expressions + ] this = expressions[0] if expressions[1:]: this = _wrap(this, Connector) @@ -4626,7 +4680,7 @@ def delete( return delete_expr -def condition(expression, dialect=None, **opts) -> Condition: +def condition(expression, dialect=None, copy=True, **opts) -> Condition: """ Initialize a logical condition expression. @@ -4645,6 +4699,7 @@ def condition(expression, dialect=None, **opts) -> Condition: If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expression (in the case that the input expression is a SQL string). + copy (bool): Whether or not to copy `expression` (only applies to expressions). **opts: other options to use to parse the input expressions (again, in the case that the input expression is a SQL string). @@ -4655,11 +4710,12 @@ def condition(expression, dialect=None, **opts) -> Condition: expression, into=Condition, dialect=dialect, + copy=copy, **opts, ) -def and_(*expressions, dialect=None, **opts) -> And: +def and_(*expressions, dialect=None, copy=True, **opts) -> And: """ Combine multiple conditions with an AND logical operator. @@ -4671,15 +4727,16 @@ def and_(*expressions, dialect=None, **opts) -> And: *expressions (str | Expression): the SQL code strings to parse. If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expression. + copy (bool): whether or not to copy `expressions` (only applies to Expressions). **opts: other options to use to parse the input expressions. Returns: And: the new condition """ - return _combine(expressions, And, dialect, **opts) + return _combine(expressions, And, dialect, copy=copy, **opts) -def or_(*expressions, dialect=None, **opts) -> Or: +def or_(*expressions, dialect=None, copy=True, **opts) -> Or: """ Combine multiple conditions with an OR logical operator. @@ -4691,15 +4748,16 @@ def or_(*expressions, dialect=None, **opts) -> Or: *expressions (str | Expression): the SQL code strings to parse. If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expression. + copy (bool): whether or not to copy `expressions` (only applies to Expressions). **opts: other options to use to parse the input expressions. Returns: Or: the new condition """ - return _combine(expressions, Or, dialect, **opts) + return _combine(expressions, Or, dialect, copy=copy, **opts) -def not_(expression, dialect=None, **opts) -> Not: +def not_(expression, dialect=None, copy=True, **opts) -> Not: """ Wrap a condition with a NOT operator. @@ -4719,13 +4777,14 @@ def not_(expression, dialect=None, **opts) -> Not: this = condition( expression, dialect=dialect, + copy=copy, **opts, ) return Not(this=_wrap(this, Connector)) -def paren(expression) -> Paren: - return Paren(this=expression) +def paren(expression, copy=True) -> Paren: + return Paren(this=_maybe_copy(expression, copy)) SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$") @@ -4998,29 +5057,20 @@ def values( alias: optional alias columns: Optional list of ordered column names or ordered dictionary of column names to types. If either are provided then an alias is also required. - If a dictionary is provided then the first column of the values will be casted to the expected type - in order to help with type inference. Returns: Values: the Values expression object """ if columns and not alias: raise ValueError("Alias is required when providing columns") - table_alias = ( - TableAlias(this=to_identifier(alias), columns=[to_identifier(x) for x in columns]) - if columns - else TableAlias(this=to_identifier(alias) if alias else None) - ) - expressions = [convert(tup) for tup in values] - if columns and isinstance(columns, dict): - types = list(columns.values()) - expressions[0].set( - "expressions", - [cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)], - ) + return Values( - expressions=expressions, - alias=table_alias, + expressions=[convert(tup) for tup in values], + alias=( + TableAlias(this=to_identifier(alias), columns=[to_identifier(x) for x in columns]) + if columns + else (TableAlias(this=to_identifier(alias)) if alias else None) + ), ) @@ -5068,19 +5118,20 @@ def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable: ) -def convert(value) -> Expression: +def convert(value: t.Any, copy: bool = False) -> Expression: """Convert a python value into an expression object. Raises an error if a conversion is not possible. Args: - value (Any): a python object + value: A python object. + copy: Whether or not to copy `value` (only applies to Expressions and collections). Returns: - Expression: the equivalent expression object + Expression: the equivalent expression object. """ if isinstance(value, Expression): - return value + return _maybe_copy(value, copy) if isinstance(value, str): return Literal.string(value) if isinstance(value, bool): @@ -5098,13 +5149,13 @@ def convert(value) -> Expression: date_literal = Literal.string(value.strftime("%Y-%m-%d")) return DateStrToDate(this=date_literal) if isinstance(value, tuple): - return Tuple(expressions=[convert(v) for v in value]) + return Tuple(expressions=[convert(v, copy=copy) for v in value]) if isinstance(value, list): - return Array(expressions=[convert(v) for v in value]) + return Array(expressions=[convert(v, copy=copy) for v in value]) if isinstance(value, dict): return Map( - keys=[convert(k) for k in value], - values=[convert(v) for v in value.values()], + keys=[convert(k, copy=copy) for k in value], + values=[convert(v, copy=copy) for v in value.values()], ) raise ValueError(f"Cannot convert {value}") |