From 7b29f6168bf9fcb2d886447066a9bb51675e5665 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 4 Oct 2022 11:37:14 +0200 Subject: Merging upstream version 6.2.8. Signed-off-by: Daniel Baumann --- sqlglot/expressions.py | 68 +++++++++++++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 25 deletions(-) (limited to 'sqlglot/expressions.py') diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 599c7db..8cdacce 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -213,21 +213,23 @@ class Expression(metaclass=_Expression): """ return self.find_ancestor(Select) - def walk(self, bfs=True): + def walk(self, bfs=True, prune=None): """ Returns a generator object which visits all nodes in this tree. Args: bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead. + prune ((node, parent, arg_key) -> bool): callable that returns True if + the generator should stop traversing this branch of the tree. Returns: the generator object. """ if bfs: - yield from self.bfs() + yield from self.bfs(prune=prune) else: - yield from self.dfs() + yield from self.dfs(prune=prune) def dfs(self, parent=None, key=None, prune=None): """ @@ -506,6 +508,10 @@ class DerivedTable(Expression): return [select.alias_or_name for select in self.selects] +class UDTF(DerivedTable): + pass + + class Annotation(Expression): arg_types = { "this": True, @@ -652,7 +658,13 @@ class Delete(Expression): class Drop(Expression): - arg_types = {"this": False, "kind": False, "exists": False} + arg_types = { + "this": False, + "kind": False, + "exists": False, + "temporary": False, + "materialized": False, + } class Filter(Expression): @@ -827,7 +839,7 @@ class Join(Expression): return join -class Lateral(DerivedTable): +class Lateral(UDTF): arg_types = {"this": True, "outer": False, "alias": False} @@ -915,6 +927,14 @@ class LanguageProperty(Property): pass +class ExecuteAsProperty(Property): + pass + + +class VolatilityProperty(Property): + arg_types = {"this": True} + + class Properties(Expression): arg_types = {"expressions": True} @@ -1098,7 +1118,7 @@ class Intersect(Union): pass -class Unnest(DerivedTable): +class Unnest(UDTF): arg_types = { "expressions": True, "ordinality": False, @@ -1116,8 +1136,12 @@ class Update(Expression): } -class Values(Expression): - arg_types = {"expressions": True} +class Values(UDTF): + arg_types = { + "expressions": True, + "ordinality": False, + "alias": False, + } class Var(Expression): @@ -2033,23 +2057,17 @@ class Func(Condition): @classmethod def from_arg_list(cls, args): - args_num = len(args) - - all_arg_keys = list(cls.arg_types) - # If this function supports variable length argument treat the last argument as such. - non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys - - args_dict = {} - arg_idx = 0 - for arg_key in non_var_len_arg_keys: - if arg_idx >= args_num: - break - if args[arg_idx] is not None: - args_dict[arg_key] = args[arg_idx] - arg_idx += 1 - - if arg_idx < args_num and cls.is_var_len_args: - args_dict[all_arg_keys[-1]] = args[arg_idx:] + if cls.is_var_len_args: + all_arg_keys = list(cls.arg_types) + # If this function supports variable length argument treat the last argument as such. + non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys + num_non_var = len(non_var_len_arg_keys) + + args_dict = {arg_key: arg for arg, arg_key in zip(args, non_var_len_arg_keys)} + args_dict[all_arg_keys[-1]] = args[num_non_var:] + else: + args_dict = {arg_key: arg for arg, arg_key in zip(args, cls.arg_types)} + return cls(**args_dict) @classmethod -- cgit v1.2.3