summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py68
1 files changed, 43 insertions, 25 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 599c7db..8cdacce 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -213,21 +213,23 @@ class Expression(metaclass=_Expression):
"""
return self.find_ancestor(Select)
- def walk(self, bfs=True):
+ def walk(self, bfs=True, prune=None):
"""
Returns a generator object which visits all nodes in this tree.
Args:
bfs (bool): if set to True the BFS traversal order will be applied,
otherwise the DFS traversal will be used instead.
+ prune ((node, parent, arg_key) -> bool): callable that returns True if
+ the generator should stop traversing this branch of the tree.
Returns:
the generator object.
"""
if bfs:
- yield from self.bfs()
+ yield from self.bfs(prune=prune)
else:
- yield from self.dfs()
+ yield from self.dfs(prune=prune)
def dfs(self, parent=None, key=None, prune=None):
"""
@@ -506,6 +508,10 @@ class DerivedTable(Expression):
return [select.alias_or_name for select in self.selects]
+class UDTF(DerivedTable):
+ pass
+
+
class Annotation(Expression):
arg_types = {
"this": True,
@@ -652,7 +658,13 @@ class Delete(Expression):
class Drop(Expression):
- arg_types = {"this": False, "kind": False, "exists": False}
+ arg_types = {
+ "this": False,
+ "kind": False,
+ "exists": False,
+ "temporary": False,
+ "materialized": False,
+ }
class Filter(Expression):
@@ -827,7 +839,7 @@ class Join(Expression):
return join
-class Lateral(DerivedTable):
+class Lateral(UDTF):
arg_types = {"this": True, "outer": False, "alias": False}
@@ -915,6 +927,14 @@ class LanguageProperty(Property):
pass
+class ExecuteAsProperty(Property):
+ pass
+
+
+class VolatilityProperty(Property):
+ arg_types = {"this": True}
+
+
class Properties(Expression):
arg_types = {"expressions": True}
@@ -1098,7 +1118,7 @@ class Intersect(Union):
pass
-class Unnest(DerivedTable):
+class Unnest(UDTF):
arg_types = {
"expressions": True,
"ordinality": False,
@@ -1116,8 +1136,12 @@ class Update(Expression):
}
-class Values(Expression):
- arg_types = {"expressions": True}
+class Values(UDTF):
+ arg_types = {
+ "expressions": True,
+ "ordinality": False,
+ "alias": False,
+ }
class Var(Expression):
@@ -2033,23 +2057,17 @@ class Func(Condition):
@classmethod
def from_arg_list(cls, args):
- args_num = len(args)
-
- all_arg_keys = list(cls.arg_types)
- # If this function supports variable length argument treat the last argument as such.
- non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
-
- args_dict = {}
- arg_idx = 0
- for arg_key in non_var_len_arg_keys:
- if arg_idx >= args_num:
- break
- if args[arg_idx] is not None:
- args_dict[arg_key] = args[arg_idx]
- arg_idx += 1
-
- if arg_idx < args_num and cls.is_var_len_args:
- args_dict[all_arg_keys[-1]] = args[arg_idx:]
+ if cls.is_var_len_args:
+ all_arg_keys = list(cls.arg_types)
+ # If this function supports variable length argument treat the last argument as such.
+ non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
+ num_non_var = len(non_var_len_arg_keys)
+
+ args_dict = {arg_key: arg for arg, arg_key in zip(args, non_var_len_arg_keys)}
+ args_dict[all_arg_keys[-1]] = args[num_non_var:]
+ else:
+ args_dict = {arg_key: arg for arg, arg_key in zip(args, cls.arg_types)}
+
return cls(**args_dict)
@classmethod