From 90150543f9314be683d22a16339effd774192f6d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 22 Sep 2022 06:31:28 +0200 Subject: Merging upstream version 6.1.1. Signed-off-by: Daniel Baumann --- sqlglot/expressions.py | 169 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 113 insertions(+), 56 deletions(-) (limited to 'sqlglot/expressions.py') diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 7acc63d..b983bf9 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -47,10 +47,7 @@ class Expression(metaclass=_Expression): return hash( ( self.key, - tuple( - (k, tuple(v) if isinstance(v, list) else v) - for k, v in _norm_args(self).items() - ), + tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()), ) ) @@ -116,9 +113,22 @@ class Expression(metaclass=_Expression): item.parent = parent return new + def append(self, arg_key, value): + """ + Appends value to arg_key if it's a list or sets it as a new list. + + Args: + arg_key (str): name of the list expression arg + value (Any): value to append to the list + """ + if not isinstance(self.args.get(arg_key), list): + self.args[arg_key] = [] + self.args[arg_key].append(value) + self._set_parent(arg_key, value) + def set(self, arg_key, value): """ - Sets `arg` to `value`. + Sets `arg_key` to `value`. Args: arg_key (str): name of the expression arg @@ -267,6 +277,14 @@ class Expression(metaclass=_Expression): expression = expression.this return expression + def unalias(self): + """ + Returns the inner expression if this is an Alias. + """ + if isinstance(self, Alias): + return self.this + return self + def unnest_operands(self): """ Returns unnested operands as a tuple. @@ -279,9 +297,7 @@ class Expression(metaclass=_Expression): A AND B AND C -> [A, B, C] """ - for node, _, _ in self.dfs( - prune=lambda n, p, *_: p and not isinstance(n, self.__class__) - ): + for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)): if not isinstance(node, self.__class__): yield node.unnest() if unnest else node @@ -314,9 +330,7 @@ class Expression(metaclass=_Expression): args = { k: ", ".join( - v.to_s(hide_missing=hide_missing, level=level + 1) - if hasattr(v, "to_s") - else str(v) + v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v) for v in ensure_list(vs) if v is not None ) @@ -354,9 +368,7 @@ class Expression(metaclass=_Expression): new_node.parent = node.parent return new_node - replace_children( - new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs) - ) + replace_children(new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs)) return new_node def replace(self, expression): @@ -546,6 +558,10 @@ class BitString(Condition): pass +class HexString(Condition): + pass + + class Column(Condition): arg_types = {"this": True, "table": False} @@ -566,35 +582,44 @@ class ColumnConstraint(Expression): arg_types = {"this": False, "kind": True} -class AutoIncrementColumnConstraint(Expression): +class ColumnConstraintKind(Expression): pass -class CheckColumnConstraint(Expression): +class AutoIncrementColumnConstraint(ColumnConstraintKind): pass -class CollateColumnConstraint(Expression): +class CheckColumnConstraint(ColumnConstraintKind): pass -class CommentColumnConstraint(Expression): +class CollateColumnConstraint(ColumnConstraintKind): pass -class DefaultColumnConstraint(Expression): +class CommentColumnConstraint(ColumnConstraintKind): pass -class NotNullColumnConstraint(Expression): +class DefaultColumnConstraint(ColumnConstraintKind): pass -class PrimaryKeyColumnConstraint(Expression): +class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): + # this: True -> ALWAYS, this: False -> BY DEFAULT + arg_types = {"this": True, "expression": False} + + +class NotNullColumnConstraint(ColumnConstraintKind): pass -class UniqueColumnConstraint(Expression): +class PrimaryKeyColumnConstraint(ColumnConstraintKind): + pass + + +class UniqueColumnConstraint(ColumnConstraintKind): pass @@ -651,9 +676,7 @@ class Identifier(Expression): return bool(self.args.get("quoted")) def __eq__(self, other): - return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg( - other.this - ) + return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this) def __hash__(self): return hash((self.key, self.this.lower())) @@ -709,9 +732,7 @@ class Literal(Condition): def __eq__(self, other): return ( - isinstance(other, Literal) - and self.this == other.this - and self.args["is_string"] == other.args["is_string"] + isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"] ) def __hash__(self): @@ -733,6 +754,7 @@ class Join(Expression): "side": False, "kind": False, "using": False, + "natural": False, } @property @@ -743,6 +765,10 @@ class Join(Expression): def side(self): return self.text("side").upper() + @property + def alias_or_name(self): + return self.this.alias_or_name + def on(self, *expressions, append=True, dialect=None, copy=True, **opts): """ Append to or set the ON expressions. @@ -873,10 +899,6 @@ class Reference(Expression): arg_types = {"this": True, "expressions": True} -class Table(Expression): - arg_types = {"this": True, "db": False, "catalog": False} - - class Tuple(Expression): arg_types = {"expressions": False} @@ -986,6 +1008,16 @@ QUERY_MODIFIERS = { } +class Table(Expression): + arg_types = { + "this": True, + "db": False, + "catalog": False, + "laterals": False, + "joins": False, + } + + class Union(Subqueryable, Expression): arg_types = { "with": False, @@ -1396,7 +1428,9 @@ class Select(Subqueryable, Expression): join.this.replace(join.this.subquery()) if join_type: - side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) + natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) + if natural: + join.set("natural", True) if side: join.set("side", side.text) if kind: @@ -1529,10 +1563,7 @@ class Select(Subqueryable, Expression): properties_expression = None if properties: properties_str = " ".join( - [ - f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" - for k, v in properties.items() - ] + [f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()] ) properties_expression = maybe_parse( properties_str, @@ -1654,6 +1685,7 @@ class DataType(Expression): DECIMAL = auto() BOOLEAN = auto() JSON = auto() + INTERVAL = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() DATE = auto() @@ -1662,15 +1694,19 @@ class DataType(Expression): MAP = auto() UUID = auto() GEOGRAPHY = auto() + GEOMETRY = auto() STRUCT = auto() NULLABLE = auto() + HLLSKETCH = auto() + SUPER = auto() + SERIAL = auto() + SMALLSERIAL = auto() + BIGSERIAL = auto() @classmethod def build(cls, dtype, **kwargs): return DataType( - this=dtype - if isinstance(dtype, DataType.Type) - else DataType.Type[dtype.upper()], + this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()], **kwargs, ) @@ -1798,6 +1834,14 @@ class Like(Binary, Predicate): pass +class SimilarTo(Binary, Predicate): + pass + + +class Distance(Binary): + pass + + class LT(Binary, Predicate): pass @@ -1899,6 +1943,10 @@ class IgnoreNulls(Expression): pass +class RespectNulls(Expression): + pass + + # Functions class Func(Condition): """ @@ -1924,9 +1972,7 @@ class Func(Condition): all_arg_keys = list(cls.arg_types) # If this function supports variable length argument treat the last argument as such. - non_var_len_arg_keys = ( - all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys - ) + non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys args_dict = {} arg_idx = 0 @@ -1944,9 +1990,7 @@ class Func(Condition): @classmethod def sql_names(cls): if cls is Func: - raise NotImplementedError( - "SQL name is only supported by concrete function implementations" - ) + raise NotImplementedError("SQL name is only supported by concrete function implementations") if not hasattr(cls, "_sql_names"): cls._sql_names = [camel_to_snake_case(cls.__name__)] return cls._sql_names @@ -2178,6 +2222,10 @@ class Greatest(Func): is_var_len_args = True +class GroupConcat(Func): + arg_types = {"this": True, "separator": False} + + class If(Func): arg_types = {"this": True, "true": True, "false": False} @@ -2274,6 +2322,10 @@ class Quantile(AggFunc): arg_types = {"this": True, "quantile": True} +class ApproxQuantile(Quantile): + pass + + class Reduce(Func): arg_types = {"this": True, "initial": True, "merge": True, "finish": True} @@ -2306,8 +2358,10 @@ class Split(Func): arg_types = {"this": True, "expression": True} +# Start may be omitted in the case of postgres +# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 class Substring(Func): - arg_types = {"this": True, "start": True, "length": False} + arg_types = {"this": True, "start": False, "length": False} class StrPosition(Func): @@ -2379,6 +2433,15 @@ class TimeStrToUnix(Func): pass +class Trim(Func): + arg_types = { + "this": True, + "position": False, + "expression": False, + "collation": False, + } + + class TsOrDsAdd(Func, TimeUnit): arg_types = {"this": True, "expression": True, "unit": False} @@ -2455,9 +2518,7 @@ def _all_functions(): obj for _, obj in inspect.getmembers( sys.modules[__name__], - lambda obj: inspect.isclass(obj) - and issubclass(obj, Func) - and obj not in (AggFunc, Anonymous, Func), + lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func), ) ] @@ -2633,9 +2694,7 @@ def _apply_conjunction_builder( def _combine(expressions, operator, dialect=None, **opts): - expressions = [ - condition(expression, dialect=dialect, **opts) for expression in expressions - ] + expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions] this = expressions[0] if expressions[1:]: this = _wrap_operator(this) @@ -2809,9 +2868,7 @@ def to_identifier(alias, quoted=None): quoted = not re.match(SAFE_IDENTIFIER_RE, alias) identifier = Identifier(this=alias, quoted=quoted) else: - raise ValueError( - f"Alias needs to be a string or an Identifier, got: {alias.__class__}" - ) + raise ValueError(f"Alias needs to be a string or an Identifier, got: {alias.__class__}") return identifier -- cgit v1.2.3