From a8b22b4c5bdf9139a187c92b7b9f81bdeaa84888 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 27 Feb 2023 11:46:36 +0100 Subject: Merging upstream version 11.2.3. Signed-off-by: Daniel Baumann --- sqlglot/diff.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 47 insertions(+), 8 deletions(-) (limited to 'sqlglot/diff.py') diff --git a/sqlglot/diff.py b/sqlglot/diff.py index 7530613..dddb9ad 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -11,8 +11,7 @@ from collections import defaultdict from dataclasses import dataclass from heapq import heappop, heappush -from sqlglot import Dialect -from sqlglot import expressions as exp +from sqlglot import Dialect, expressions as exp from sqlglot.helper import ensure_collection @@ -58,7 +57,12 @@ if t.TYPE_CHECKING: Edit = t.Union[Insert, Remove, Move, Update, Keep] -def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]: +def diff( + source: exp.Expression, + target: exp.Expression, + matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None, + **kwargs: t.Any, +) -> t.List[Edit]: """ Returns the list of changes between the source and the target expressions. @@ -80,13 +84,38 @@ def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]: Args: source: the source expression. target: the target expression against which the diff should be calculated. + matchings: the list of pre-matched node pairs which is used to help the algorithm's + heuristics produce better results for subtrees that are known by a caller to be matching. + Note: expression references in this list must refer to the same node objects that are + referenced in source / target trees. Returns: the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the target expression trees. This list represents a sequence of steps needed to transform the source expression tree into the target one. """ - return ChangeDistiller().diff(source.copy(), target.copy()) + matchings = matchings or [] + matching_ids = {id(n) for pair in matchings for n in pair} + + def compute_node_mappings( + original: exp.Expression, copy: exp.Expression + ) -> t.Dict[int, exp.Expression]: + return { + id(old_node): new_node + for (old_node, _, _), (new_node, _, _) in zip(original.walk(), copy.walk()) + if id(old_node) in matching_ids + } + + source_copy = source.copy() + target_copy = target.copy() + + node_mappings = { + **compute_node_mappings(source, source_copy), + **compute_node_mappings(target, target_copy), + } + matchings_copy = [(node_mappings[id(s)], node_mappings[id(t)]) for s, t in matchings] + + return ChangeDistiller(**kwargs).diff(source_copy, target_copy, matchings=matchings_copy) LEAF_EXPRESSION_TYPES = ( @@ -109,16 +138,26 @@ class ChangeDistiller: self.t = t self._sql_generator = Dialect().generator() - def diff(self, source: exp.Expression, target: exp.Expression) -> t.List[Edit]: + def diff( + self, + source: exp.Expression, + target: exp.Expression, + matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None, + ) -> t.List[Edit]: + matchings = matchings or [] + pre_matched_nodes = {id(s): id(t) for s, t in matchings} + if len({n for pair in pre_matched_nodes.items() for n in pair}) != 2 * len(matchings): + raise ValueError("Each node can be referenced at most once in the list of matchings") + self._source = source self._target = target self._source_index = {id(n[0]): n[0] for n in source.bfs()} self._target_index = {id(n[0]): n[0] for n in target.bfs()} - self._unmatched_source_nodes = set(self._source_index) - self._unmatched_target_nodes = set(self._target_index) + self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes) + self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values()) self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {} - matching_set = self._compute_matching_set() + matching_set = self._compute_matching_set() | {(s, t) for s, t in pre_matched_nodes.items()} return self._generate_edit_script(matching_set) def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]: -- cgit v1.2.3