summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py163
1 files changed, 91 insertions, 72 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index b9da4cc..f4aae47 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -26,6 +26,7 @@ from sqlglot.helper import (
AutoName,
camel_to_snake_case,
ensure_collection,
+ ensure_list,
seq_get,
split_num_words,
subclasses,
@@ -84,7 +85,7 @@ class Expression(metaclass=_Expression):
key = "expression"
arg_types = {"this": True}
- __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta")
+ __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta", "_hash")
def __init__(self, **args: t.Any):
self.args: t.Dict[str, t.Any] = args
@@ -93,23 +94,31 @@ class Expression(metaclass=_Expression):
self.comments: t.Optional[t.List[str]] = None
self._type: t.Optional[DataType] = None
self._meta: t.Optional[t.Dict[str, t.Any]] = None
+ self._hash: t.Optional[int] = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
def __eq__(self, other) -> bool:
- return type(self) is type(other) and _norm_args(self) == _norm_args(other)
+ return type(self) is type(other) and hash(self) == hash(other)
- def __hash__(self) -> int:
- return hash(
- (
- self.key,
- tuple(
- (k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()
- ),
- )
+ @property
+ def hashable_args(self) -> t.Any:
+ args = (self.args.get(k) for k in self.arg_types)
+
+ return tuple(
+ (tuple(_norm_arg(a) for a in arg) if arg else None)
+ if type(arg) is list
+ else (_norm_arg(arg) if arg is not None and arg is not False else None)
+ for arg in args
)
+ def __hash__(self) -> int:
+ if self._hash is not None:
+ return self._hash
+
+ return hash((self.__class__, self.hashable_args))
+
@property
def this(self):
"""
@@ -247,9 +256,6 @@ class Expression(metaclass=_Expression):
"""
new = deepcopy(self)
new.parent = self.parent
- for item, parent, _ in new.bfs():
- if isinstance(item, Expression) and parent:
- item.parent = parent
return new
def append(self, arg_key, value):
@@ -277,12 +283,12 @@ class Expression(metaclass=_Expression):
self._set_parent(arg_key, value)
def _set_parent(self, arg_key, value):
- if isinstance(value, Expression):
+ if hasattr(value, "parent"):
value.parent = self
value.arg_key = arg_key
- elif isinstance(value, list):
+ elif type(value) is list:
for v in value:
- if isinstance(v, Expression):
+ if hasattr(v, "parent"):
v.parent = self
v.arg_key = arg_key
@@ -295,6 +301,17 @@ class Expression(metaclass=_Expression):
return self.parent.depth + 1
return 0
+ def iter_expressions(self) -> t.Iterator[t.Tuple[str, Expression]]:
+ """Yields the key and expression for all arguments, exploding list args."""
+ for k, vs in self.args.items():
+ if type(vs) is list:
+ for v in vs:
+ if hasattr(v, "parent"):
+ yield k, v
+ else:
+ if hasattr(vs, "parent"):
+ yield k, vs
+
def find(self, *expression_types: t.Type[E], bfs=True) -> E | None:
"""
Returns the first node in this tree which matches at least one of
@@ -319,7 +336,7 @@ class Expression(metaclass=_Expression):
Returns:
The generator object.
"""
- for expression, _, _ in self.walk(bfs=bfs):
+ for expression, *_ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
@@ -345,6 +362,11 @@ class Expression(metaclass=_Expression):
"""
return self.find_ancestor(Select)
+ @property
+ def same_parent(self):
+ """Returns if the parent is the same class as itself."""
+ return type(self.parent) is self.__class__
+
def root(self) -> Expression:
"""
Returns the root expression of this tree.
@@ -385,10 +407,8 @@ class Expression(metaclass=_Expression):
if prune and prune(self, parent, key):
return
- for k, v in self.args.items():
- for node in ensure_collection(v):
- if isinstance(node, Expression):
- yield from node.dfs(self, k, prune)
+ for k, v in self.iter_expressions():
+ yield from v.dfs(self, k, prune)
def bfs(self, prune=None):
"""
@@ -407,18 +427,15 @@ class Expression(metaclass=_Expression):
if prune and prune(item, parent, key):
continue
- if isinstance(item, Expression):
- for k, v in item.args.items():
- for node in ensure_collection(v):
- if isinstance(node, Expression):
- queue.append((node, item, k))
+ for k, v in item.iter_expressions():
+ queue.append((v, item, k))
def unnest(self):
"""
Returns the first non parenthesis child or self.
"""
expression = self
- while isinstance(expression, Paren):
+ while type(expression) is Paren:
expression = expression.this
return expression
@@ -434,7 +451,7 @@ class Expression(metaclass=_Expression):
"""
Returns unnested operands as a tuple.
"""
- return tuple(arg.unnest() for arg in self.args.values() if arg)
+ return tuple(arg.unnest() for _, arg in self.iter_expressions())
def flatten(self, unnest=True):
"""
@@ -442,8 +459,8 @@ class Expression(metaclass=_Expression):
A AND B AND C -> [A, B, C]
"""
- for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)):
- if not isinstance(node, self.__class__):
+ for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
+ if not type(node) is self.__class__:
yield node.unnest() if unnest else node
def __str__(self):
@@ -477,7 +494,7 @@ class Expression(metaclass=_Expression):
v._to_s(hide_missing=hide_missing, level=level + 1)
if hasattr(v, "_to_s")
else str(v)
- for v in ensure_collection(vs)
+ for v in ensure_list(vs)
if v is not None
)
for k, vs in self.args.items()
@@ -812,6 +829,10 @@ class Describe(Expression):
arg_types = {"this": True, "kind": False}
+class Pragma(Expression):
+ pass
+
+
class Set(Expression):
arg_types = {"expressions": False}
@@ -1170,6 +1191,7 @@ class Drop(Expression):
"temporary": False,
"materialized": False,
"cascade": False,
+ "constraints": False,
}
@@ -1232,11 +1254,11 @@ class Identifier(Expression):
def quoted(self):
return bool(self.args.get("quoted"))
- def __eq__(self, other):
- return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this)
-
- def __hash__(self):
- return hash((self.key, self.this.lower()))
+ @property
+ def hashable_args(self) -> t.Any:
+ if self.quoted and any(char.isupper() for char in self.this):
+ return (self.this, self.quoted)
+ return self.this.lower()
@property
def output_name(self):
@@ -1322,15 +1344,9 @@ class Limit(Expression):
class Literal(Condition):
arg_types = {"this": True, "is_string": True}
- def __eq__(self, other):
- return (
- isinstance(other, Literal)
- and self.this == other.this
- and self.args["is_string"] == other.args["is_string"]
- )
-
- def __hash__(self):
- return hash((self.key, self.this, self.args["is_string"]))
+ @property
+ def hashable_args(self) -> t.Any:
+ return (self.this, self.args.get("is_string"))
@classmethod
def number(cls, number) -> Literal:
@@ -1784,7 +1800,7 @@ class Subqueryable(Unionable):
instance = _maybe_copy(self, copy)
return Subquery(
this=instance,
- alias=TableAlias(this=to_identifier(alias)),
+ alias=TableAlias(this=to_identifier(alias)) if alias else None,
)
def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
@@ -2058,6 +2074,7 @@ class Lock(Expression):
class Select(Subqueryable):
arg_types = {
"with": False,
+ "kind": False,
"expressions": False,
"hint": False,
"distinct": False,
@@ -3595,6 +3612,21 @@ class Initcap(Func):
pass
+class JSONKeyValue(Expression):
+ arg_types = {"this": True, "expression": True}
+
+
+class JSONObject(Func):
+ arg_types = {
+ "expressions": False,
+ "null_handling": False,
+ "unique_keys": False,
+ "return_type": False,
+ "format_json": False,
+ "encoding": False,
+ }
+
+
class JSONBContains(Binary):
_sql_names = ["JSONB_CONTAINS"]
@@ -3766,8 +3798,10 @@ class RegexpILike(Func):
arg_types = {"this": True, "expression": True, "flag": False}
+# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split.html
+# limit is the number of times a pattern is applied
class RegexpSplit(Func):
- arg_types = {"this": True, "expression": True}
+ arg_types = {"this": True, "expression": True, "limit": False}
class Repeat(Func):
@@ -3967,25 +4001,8 @@ class When(Func):
arg_types = {"matched": True, "source": False, "condition": False, "then": True}
-def _norm_args(expression):
- args = {}
-
- for k, arg in expression.args.items():
- if isinstance(arg, list):
- arg = [_norm_arg(a) for a in arg]
- if not arg:
- arg = None
- else:
- arg = _norm_arg(arg)
-
- if arg is not None and arg is not False:
- args[k] = arg
-
- return args
-
-
def _norm_arg(arg):
- return arg.lower() if isinstance(arg, str) else arg
+ return arg.lower() if type(arg) is str else arg
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
@@ -4512,7 +4529,7 @@ def to_identifier(name, quoted=None):
elif isinstance(name, str):
identifier = Identifier(
this=name,
- quoted=not re.match(SAFE_IDENTIFIER_RE, name) if quoted is None else quoted,
+ quoted=not SAFE_IDENTIFIER_RE.match(name) if quoted is None else quoted,
)
else:
raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}")
@@ -4586,8 +4603,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:
return sql_path
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
- table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2))
- return Column(this=column_name, table=table_name, **kwargs)
+ return column(*reversed(sql_path.split(".")), **kwargs) # type: ignore
def alias_(
@@ -4672,7 +4688,8 @@ def subquery(expression, alias=None, dialect=None, **opts):
def column(
col: str | Identifier,
table: t.Optional[str | Identifier] = None,
- schema: t.Optional[str | Identifier] = None,
+ db: t.Optional[str | Identifier] = None,
+ catalog: t.Optional[str | Identifier] = None,
quoted: t.Optional[bool] = None,
) -> Column:
"""
@@ -4681,7 +4698,8 @@ def column(
Args:
col: column name
table: table name
- schema: schema name
+ db: db name
+ catalog: catalog name
quoted: whether or not to force quote each part
Returns:
Column: column instance
@@ -4689,7 +4707,8 @@ def column(
return Column(
this=to_identifier(col, quoted=quoted),
table=to_identifier(table, quoted=quoted),
- schema=to_identifier(schema, quoted=quoted),
+ db=to_identifier(db, quoted=quoted),
+ catalog=to_identifier(catalog, quoted=quoted),
)
@@ -4864,7 +4883,7 @@ def replace_children(expression, fun, *args, **kwargs):
Replace children of an expression with the result of a lambda fun(child) -> exp.
"""
for k, v in expression.args.items():
- is_list_arg = isinstance(v, list)
+ is_list_arg = type(v) is list
child_nodes = v if is_list_arg else [v]
new_child_nodes = []