diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-04-03 07:31:54 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-04-03 07:31:54 +0000 |
commit | b38d717d5933fdae3fe85c87df7aee9a251fb58e (patch) | |
tree | 6db21a44ffea4c832dcab29688bfaf1c1dc124f9 /sqlglot/expressions.py | |
parent | Releasing debian version 11.4.1-1. (diff) | |
download | sqlglot-b38d717d5933fdae3fe85c87df7aee9a251fb58e.tar.xz sqlglot-b38d717d5933fdae3fe85c87df7aee9a251fb58e.zip |
Merging upstream version 11.4.5.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- | sqlglot/expressions.py | 163 |
1 files changed, 91 insertions, 72 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index b9da4cc..f4aae47 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -26,6 +26,7 @@ from sqlglot.helper import ( AutoName, camel_to_snake_case, ensure_collection, + ensure_list, seq_get, split_num_words, subclasses, @@ -84,7 +85,7 @@ class Expression(metaclass=_Expression): key = "expression" arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta") + __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta", "_hash") def __init__(self, **args: t.Any): self.args: t.Dict[str, t.Any] = args @@ -93,23 +94,31 @@ class Expression(metaclass=_Expression): self.comments: t.Optional[t.List[str]] = None self._type: t.Optional[DataType] = None self._meta: t.Optional[t.Dict[str, t.Any]] = None + self._hash: t.Optional[int] = None for arg_key, value in self.args.items(): self._set_parent(arg_key, value) def __eq__(self, other) -> bool: - return type(self) is type(other) and _norm_args(self) == _norm_args(other) + return type(self) is type(other) and hash(self) == hash(other) - def __hash__(self) -> int: - return hash( - ( - self.key, - tuple( - (k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items() - ), - ) + @property + def hashable_args(self) -> t.Any: + args = (self.args.get(k) for k in self.arg_types) + + return tuple( + (tuple(_norm_arg(a) for a in arg) if arg else None) + if type(arg) is list + else (_norm_arg(arg) if arg is not None and arg is not False else None) + for arg in args ) + def __hash__(self) -> int: + if self._hash is not None: + return self._hash + + return hash((self.__class__, self.hashable_args)) + @property def this(self): """ @@ -247,9 +256,6 @@ class Expression(metaclass=_Expression): """ new = deepcopy(self) new.parent = self.parent - for item, parent, _ in new.bfs(): - if isinstance(item, Expression) and parent: - item.parent = parent return new def append(self, arg_key, value): @@ -277,12 +283,12 @@ class Expression(metaclass=_Expression): self._set_parent(arg_key, value) def _set_parent(self, arg_key, value): - if isinstance(value, Expression): + if hasattr(value, "parent"): value.parent = self value.arg_key = arg_key - elif isinstance(value, list): + elif type(value) is list: for v in value: - if isinstance(v, Expression): + if hasattr(v, "parent"): v.parent = self v.arg_key = arg_key @@ -295,6 +301,17 @@ class Expression(metaclass=_Expression): return self.parent.depth + 1 return 0 + def iter_expressions(self) -> t.Iterator[t.Tuple[str, Expression]]: + """Yields the key and expression for all arguments, exploding list args.""" + for k, vs in self.args.items(): + if type(vs) is list: + for v in vs: + if hasattr(v, "parent"): + yield k, v + else: + if hasattr(vs, "parent"): + yield k, vs + def find(self, *expression_types: t.Type[E], bfs=True) -> E | None: """ Returns the first node in this tree which matches at least one of @@ -319,7 +336,7 @@ class Expression(metaclass=_Expression): Returns: The generator object. """ - for expression, _, _ in self.walk(bfs=bfs): + for expression, *_ in self.walk(bfs=bfs): if isinstance(expression, expression_types): yield expression @@ -345,6 +362,11 @@ class Expression(metaclass=_Expression): """ return self.find_ancestor(Select) + @property + def same_parent(self): + """Returns if the parent is the same class as itself.""" + return type(self.parent) is self.__class__ + def root(self) -> Expression: """ Returns the root expression of this tree. @@ -385,10 +407,8 @@ class Expression(metaclass=_Expression): if prune and prune(self, parent, key): return - for k, v in self.args.items(): - for node in ensure_collection(v): - if isinstance(node, Expression): - yield from node.dfs(self, k, prune) + for k, v in self.iter_expressions(): + yield from v.dfs(self, k, prune) def bfs(self, prune=None): """ @@ -407,18 +427,15 @@ class Expression(metaclass=_Expression): if prune and prune(item, parent, key): continue - if isinstance(item, Expression): - for k, v in item.args.items(): - for node in ensure_collection(v): - if isinstance(node, Expression): - queue.append((node, item, k)) + for k, v in item.iter_expressions(): + queue.append((v, item, k)) def unnest(self): """ Returns the first non parenthesis child or self. """ expression = self - while isinstance(expression, Paren): + while type(expression) is Paren: expression = expression.this return expression @@ -434,7 +451,7 @@ class Expression(metaclass=_Expression): """ Returns unnested operands as a tuple. """ - return tuple(arg.unnest() for arg in self.args.values() if arg) + return tuple(arg.unnest() for _, arg in self.iter_expressions()) def flatten(self, unnest=True): """ @@ -442,8 +459,8 @@ class Expression(metaclass=_Expression): A AND B AND C -> [A, B, C] """ - for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)): - if not isinstance(node, self.__class__): + for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__): + if not type(node) is self.__class__: yield node.unnest() if unnest else node def __str__(self): @@ -477,7 +494,7 @@ class Expression(metaclass=_Expression): v._to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "_to_s") else str(v) - for v in ensure_collection(vs) + for v in ensure_list(vs) if v is not None ) for k, vs in self.args.items() @@ -812,6 +829,10 @@ class Describe(Expression): arg_types = {"this": True, "kind": False} +class Pragma(Expression): + pass + + class Set(Expression): arg_types = {"expressions": False} @@ -1170,6 +1191,7 @@ class Drop(Expression): "temporary": False, "materialized": False, "cascade": False, + "constraints": False, } @@ -1232,11 +1254,11 @@ class Identifier(Expression): def quoted(self): return bool(self.args.get("quoted")) - def __eq__(self, other): - return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this) - - def __hash__(self): - return hash((self.key, self.this.lower())) + @property + def hashable_args(self) -> t.Any: + if self.quoted and any(char.isupper() for char in self.this): + return (self.this, self.quoted) + return self.this.lower() @property def output_name(self): @@ -1322,15 +1344,9 @@ class Limit(Expression): class Literal(Condition): arg_types = {"this": True, "is_string": True} - def __eq__(self, other): - return ( - isinstance(other, Literal) - and self.this == other.this - and self.args["is_string"] == other.args["is_string"] - ) - - def __hash__(self): - return hash((self.key, self.this, self.args["is_string"])) + @property + def hashable_args(self) -> t.Any: + return (self.this, self.args.get("is_string")) @classmethod def number(cls, number) -> Literal: @@ -1784,7 +1800,7 @@ class Subqueryable(Unionable): instance = _maybe_copy(self, copy) return Subquery( this=instance, - alias=TableAlias(this=to_identifier(alias)), + alias=TableAlias(this=to_identifier(alias)) if alias else None, ) def limit(self, expression, dialect=None, copy=True, **opts) -> Select: @@ -2058,6 +2074,7 @@ class Lock(Expression): class Select(Subqueryable): arg_types = { "with": False, + "kind": False, "expressions": False, "hint": False, "distinct": False, @@ -3595,6 +3612,21 @@ class Initcap(Func): pass +class JSONKeyValue(Expression): + arg_types = {"this": True, "expression": True} + + +class JSONObject(Func): + arg_types = { + "expressions": False, + "null_handling": False, + "unique_keys": False, + "return_type": False, + "format_json": False, + "encoding": False, + } + + class JSONBContains(Binary): _sql_names = ["JSONB_CONTAINS"] @@ -3766,8 +3798,10 @@ class RegexpILike(Func): arg_types = {"this": True, "expression": True, "flag": False} +# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split.html +# limit is the number of times a pattern is applied class RegexpSplit(Func): - arg_types = {"this": True, "expression": True} + arg_types = {"this": True, "expression": True, "limit": False} class Repeat(Func): @@ -3967,25 +4001,8 @@ class When(Func): arg_types = {"matched": True, "source": False, "condition": False, "then": True} -def _norm_args(expression): - args = {} - - for k, arg in expression.args.items(): - if isinstance(arg, list): - arg = [_norm_arg(a) for a in arg] - if not arg: - arg = None - else: - arg = _norm_arg(arg) - - if arg is not None and arg is not False: - args[k] = arg - - return args - - def _norm_arg(arg): - return arg.lower() if isinstance(arg, str) else arg + return arg.lower() if type(arg) is str else arg ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) @@ -4512,7 +4529,7 @@ def to_identifier(name, quoted=None): elif isinstance(name, str): identifier = Identifier( this=name, - quoted=not re.match(SAFE_IDENTIFIER_RE, name) if quoted is None else quoted, + quoted=not SAFE_IDENTIFIER_RE.match(name) if quoted is None else quoted, ) else: raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}") @@ -4586,8 +4603,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column: return sql_path if not isinstance(sql_path, str): raise ValueError(f"Invalid type provided for column: {type(sql_path)}") - table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2)) - return Column(this=column_name, table=table_name, **kwargs) + return column(*reversed(sql_path.split(".")), **kwargs) # type: ignore def alias_( @@ -4672,7 +4688,8 @@ def subquery(expression, alias=None, dialect=None, **opts): def column( col: str | Identifier, table: t.Optional[str | Identifier] = None, - schema: t.Optional[str | Identifier] = None, + db: t.Optional[str | Identifier] = None, + catalog: t.Optional[str | Identifier] = None, quoted: t.Optional[bool] = None, ) -> Column: """ @@ -4681,7 +4698,8 @@ def column( Args: col: column name table: table name - schema: schema name + db: db name + catalog: catalog name quoted: whether or not to force quote each part Returns: Column: column instance @@ -4689,7 +4707,8 @@ def column( return Column( this=to_identifier(col, quoted=quoted), table=to_identifier(table, quoted=quoted), - schema=to_identifier(schema, quoted=quoted), + db=to_identifier(db, quoted=quoted), + catalog=to_identifier(catalog, quoted=quoted), ) @@ -4864,7 +4883,7 @@ def replace_children(expression, fun, *args, **kwargs): Replace children of an expression with the result of a lambda fun(child) -> exp. """ for k, v in expression.args.items(): - is_list_arg = isinstance(v, list) + is_list_arg = type(v) is list child_nodes = v if is_list_arg else [v] new_child_nodes = [] |