From b38d717d5933fdae3fe85c87df7aee9a251fb58e Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 3 Apr 2023 09:31:54 +0200 Subject: Merging upstream version 11.4.5. Signed-off-by: Daniel Baumann --- sqlglot/diff.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) (limited to 'sqlglot/diff.py') diff --git a/sqlglot/diff.py b/sqlglot/diff.py index dddb9ad..86665e0 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from heapq import heappop, heappush from sqlglot import Dialect, expressions as exp -from sqlglot.helper import ensure_collection +from sqlglot.helper import ensure_list @dataclass(frozen=True) @@ -151,8 +151,8 @@ class ChangeDistiller: self._source = source self._target = target - self._source_index = {id(n[0]): n[0] for n in source.bfs()} - self._target_index = {id(n[0]): n[0] for n in target.bfs()} + self._source_index = {id(n): n for n, *_ in self._source.bfs()} + self._target_index = {id(n): n for n, *_ in self._target.bfs()} self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes) self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values()) self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {} @@ -199,10 +199,10 @@ class ChangeDistiller: matching_set = leaves_matching_set.copy() ordered_unmatched_source_nodes = { - id(n[0]): None for n in self._source.bfs() if id(n[0]) in self._unmatched_source_nodes + id(n): None for n, *_ in self._source.bfs() if id(n) in self._unmatched_source_nodes } ordered_unmatched_target_nodes = { - id(n[0]): None for n in self._target.bfs() if id(n[0]) in self._unmatched_target_nodes + id(n): None for n, *_ in self._target.bfs() if id(n) in self._unmatched_target_nodes } for source_node_id in ordered_unmatched_source_nodes: @@ -304,18 +304,18 @@ class ChangeDistiller: def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]: has_child_exprs = False - for a in expression.args.values(): - for node in ensure_collection(a): - if isinstance(node, exp.Expression): - has_child_exprs = True - yield from _get_leaves(node) + for _, node in expression.iter_expressions(): + has_child_exprs = True + yield from _get_leaves(node) if not has_child_exprs: yield expression def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool: - if type(source) is type(target): + if type(source) is type(target) and ( + not isinstance(source, exp.Identifier) or type(source.parent) is type(target.parent) + ): if isinstance(source, exp.Join): return source.args.get("side") == target.args.get("side") @@ -331,7 +331,7 @@ def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]: args: t.List[t.Union[exp.Expression, t.List]] = [] if expression: for a in expression.args.values(): - args.extend(ensure_collection(a)) + args.extend(ensure_list(a)) return [a for a in args if isinstance(a, exp.Expression)] -- cgit v1.2.3