summary refs log tree commit diff stats
path: root/sqlglot/diff.py
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-02-27 10:46:36 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-02-27 10:46:36 +0000
commit: a8b22b4c5bdf9139a187c92b7b9f81bdeaa84888 (patch)
tree: 93b8523df3ce9e02e435f56e493bd9b724eb9c7c /sqlglot/diff.py
parent: Releasing debian version 11.2.0-1. (diff)
download: sqlglot-a8b22b4c5bdf9139a187c92b7b9f81bdeaa84888.tar.xz
sqlglot-a8b22b4c5bdf9139a187c92b7b9f81bdeaa84888.zip
Merging upstream version 11.2.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/diff.py')
-rw-r--r-- sqlglot/diff.py 55
1 files changed, 47 insertions, 8 deletions
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index 7530613..dddb9ad 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -11,8 +11,7 @@ from collections import defaultdict
from dataclasses import dataclass
from heapq import heappop, heappush
-from sqlglot import Dialect
-from sqlglot import expressions as exp
+from sqlglot import Dialect, expressions as exp
from sqlglot.helper import ensure_collection
@@ -58,7 +57,12 @@ if t.TYPE_CHECKING:
Edit = t.Union[Insert, Remove, Move, Update, Keep]
-def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
+def diff(
+ source: exp.Expression,
+ target: exp.Expression,
+ matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None,
+ **kwargs: t.Any,
+) -> t.List[Edit]:
"""
Returns the list of changes between the source and the target expressions.
@@ -80,13 +84,38 @@ def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
Args:
source: the source expression.
target: the target expression against which the diff should be calculated.
+ matchings: the list of pre-matched node pairs which is used to help the algorithm's
+ heuristics produce better results for subtrees that are known by a caller to be matching.
+ Note: expression references in this list must refer to the same node objects that are
+ referenced in source / target trees.
Returns:
the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the
target expression trees. This list represents a sequence of steps needed to transform the source
expression tree into the target one.
"""
- return ChangeDistiller().diff(source.copy(), target.copy())
+ matchings = matchings or []
+ matching_ids = {id(n) for pair in matchings for n in pair}
+
+ def compute_node_mappings(
+ original: exp.Expression, copy: exp.Expression
+ ) -> t.Dict[int, exp.Expression]:
+ return {
+ id(old_node): new_node
+ for (old_node, _, _), (new_node, _, _) in zip(original.walk(), copy.walk())
+ if id(old_node) in matching_ids
+ }
+
+ source_copy = source.copy()
+ target_copy = target.copy()
+
+ node_mappings = {
+ **compute_node_mappings(source, source_copy),
+ **compute_node_mappings(target, target_copy),
+ }
+ matchings_copy = [(node_mappings[id(s)], node_mappings[id(t)]) for s, t in matchings]
+
+ return ChangeDistiller(**kwargs).diff(source_copy, target_copy, matchings=matchings_copy)
LEAF_EXPRESSION_TYPES = (
@@ -109,16 +138,26 @@ class ChangeDistiller:
self.t = t
self._sql_generator = Dialect().generator()
- def diff(self, source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
+ def diff(
+ self,
+ source: exp.Expression,
+ target: exp.Expression,
+ matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None,
+ ) -> t.List[Edit]:
+ matchings = matchings or []
+ pre_matched_nodes = {id(s): id(t) for s, t in matchings}
+ if len({n for pair in pre_matched_nodes.items() for n in pair}) != 2 * len(matchings):
+ raise ValueError("Each node can be referenced at most once in the list of matchings")
+
self._source = source
self._target = target
self._source_index = {id(n[0]): n[0] for n in source.bfs()}
self._target_index = {id(n[0]): n[0] for n in target.bfs()}
- self._unmatched_source_nodes = set(self._source_index)
- self._unmatched_target_nodes = set(self._target_index)
+ self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes)
+ self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
- matching_set = self._compute_matching_set()
+ matching_set = self._compute_matching_set() | {(s, t) for s, t in pre_matched_nodes.items()}
return self._generate_edit_script(matching_set)
def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]: