summaryrefslogtreecommitdiffstats
path: root/sqlglot/diff.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/diff.py')
-rw-r--r--sqlglot/diff.py24
1 files changed, 12 insertions, 12 deletions
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index dddb9ad..86665e0 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -12,7 +12,7 @@ from dataclasses import dataclass
from heapq import heappop, heappush
from sqlglot import Dialect, expressions as exp
-from sqlglot.helper import ensure_collection
+from sqlglot.helper import ensure_list
@dataclass(frozen=True)
@@ -151,8 +151,8 @@ class ChangeDistiller:
self._source = source
self._target = target
- self._source_index = {id(n[0]): n[0] for n in source.bfs()}
- self._target_index = {id(n[0]): n[0] for n in target.bfs()}
+ self._source_index = {id(n): n for n, *_ in self._source.bfs()}
+ self._target_index = {id(n): n for n, *_ in self._target.bfs()}
self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes)
self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
@@ -199,10 +199,10 @@ class ChangeDistiller:
matching_set = leaves_matching_set.copy()
ordered_unmatched_source_nodes = {
- id(n[0]): None for n in self._source.bfs() if id(n[0]) in self._unmatched_source_nodes
+ id(n): None for n, *_ in self._source.bfs() if id(n) in self._unmatched_source_nodes
}
ordered_unmatched_target_nodes = {
- id(n[0]): None for n in self._target.bfs() if id(n[0]) in self._unmatched_target_nodes
+ id(n): None for n, *_ in self._target.bfs() if id(n) in self._unmatched_target_nodes
}
for source_node_id in ordered_unmatched_source_nodes:
@@ -304,18 +304,18 @@ class ChangeDistiller:
def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
has_child_exprs = False
- for a in expression.args.values():
- for node in ensure_collection(a):
- if isinstance(node, exp.Expression):
- has_child_exprs = True
- yield from _get_leaves(node)
+ for _, node in expression.iter_expressions():
+ has_child_exprs = True
+ yield from _get_leaves(node)
if not has_child_exprs:
yield expression
def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
- if type(source) is type(target):
+ if type(source) is type(target) and (
+ not isinstance(source, exp.Identifier) or type(source.parent) is type(target.parent)
+ ):
if isinstance(source, exp.Join):
return source.args.get("side") == target.args.get("side")
@@ -331,7 +331,7 @@ def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
args: t.List[t.Union[exp.Expression, t.List]] = []
if expression:
for a in expression.args.values():
- args.extend(ensure_collection(a))
+ args.extend(ensure_list(a))
return [a for a in args if isinstance(a, exp.Expression)]