from collections import defaultdict
from dataclasses import dataclass
from heapq import heappop, heappush

from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot.helper import ensure_list


@dataclass(frozen=True)
class Insert:
    """Indicates that a new node has been inserted"""

    expression: exp.Expression


@dataclass(frozen=True)
class Remove:
    """Indicates that an existing node has been removed"""

    expression: exp.Expression


@dataclass(frozen=True)
class Move:
    """Indicates that an existing node's position within the tree has changed"""

    expression: exp.Expression


@dataclass(frozen=True)
class Update:
    """Indicates that an existing node has been updated"""

    source: exp.Expression
    target: exp.Expression


@dataclass(frozen=True)
class Keep:
    """Indicates that an existing node hasn't been changed"""

    source: exp.Expression
    target: exp.Expression


def diff(source, target):
    """
    Returns the list of changes between the source and the target expressions.

    Both inputs are copied before diffing, so the nodes referenced by the returned
    edits belong to those copies rather than to the caller's trees.

    Examples (``parse_one`` is ``sqlglot.parse_one``):
        >>> diff(parse_one("a + b"), parse_one("a + c"))
        [
            Remove(expression=(COLUMN this: (IDENTIFIER this: b, quoted: False))),
            Insert(expression=(COLUMN this: (IDENTIFIER this: c, quoted: False))),
            Keep(
                source=(ADD this: ...),
                target=(ADD this: ...)
            ),
            Keep(
                source=(COLUMN this: (IDENTIFIER this: a, quoted: False)),
                target=(COLUMN this: (IDENTIFIER this: a, quoted: False))
            ),
        ]

    Args:
        source (sqlglot.Expression): the source expression.
        target (sqlglot.Expression): the target expression against which the diff should be calculated.

    Returns:
        the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the target expression trees.
        This list represents a sequence of steps needed to transform the source expression tree into the target one.
        Remove edits come first (source BFS order), then Insert edits (target BFS order), then the
        Move/Keep/Update edits of matched node pairs (source BFS order).
    """
    return ChangeDistiller().diff(source.copy(), target.copy())


# Node types for which a matched-but-unequal pair is reported as an Update rather than a Keep.
LEAF_EXPRESSION_TYPES = (
    exp.Boolean,
    exp.DataType,
    exp.Identifier,
    exp.Literal,
)


class ChangeDistiller:
    """
    The implementation of the Change Distiller algorithm described by Beat Fluri and Martin Pinzger in
    their paper https://ieeexplore.ieee.org/document/4339230, which in turn is based on the algorithm by
    Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf.
    """

    def __init__(self, f=0.6, t=0.6):
        # f: minimum bigram Dice coefficient for two nodes to be considered similar.
        # t: minimum leaf similarity for matching two inner nodes; relaxed to 0.4 when
        #    either node has at most 4 leaves (see _compute_matching_set).
        self.f = f
        self.t = t
        self._sql_generator = Dialect().generator()

    def diff(self, source, target):
        """Compute the edit script that transforms ``source`` into ``target``.

        Unlike the module-level ``diff`` helper, the trees are used as given (not copied).
        """
        self._source = source
        self._target = target
        # Nodes are keyed by id(); the dicts' insertion order is BFS order, which the
        # matching and edit-script steps rely on for deterministic results.
        self._source_index = {id(n[0]): n[0] for n in source.bfs()}
        self._target_index = {id(n[0]): n[0] for n in target.bfs()}
        self._unmatched_source_nodes = set(self._source_index)
        self._unmatched_target_nodes = set(self._target_index)
        self._bigram_histo_cache = {}

        matching_set = self._compute_matching_set()
        return self._generate_edit_script(matching_set)

    def _generate_edit_script(self, matching_set):
        """Turn the node matching into a list of edits.

        Unmatched source nodes become Remove, unmatched target nodes become Insert, and
        each matched pair becomes a Keep (preceded by Move edits for its children) or,
        for leaf-type nodes that differ, an Update. Each group is emitted in BFS order
        so that the result does not depend on object addresses.
        """
        edit_script = [
            Remove(node)
            for node_id, node in self._source_index.items()
            if node_id in self._unmatched_source_nodes
        ]
        edit_script.extend(
            Insert(node)
            for node_id, node in self._target_index.items()
            if node_id in self._unmatched_target_nodes
        )

        # Every source node is matched at most once, so this mapping loses no pair.
        matched_targets = dict(matching_set)
        for source_node_id, source_node in self._source_index.items():
            target_node_id = matched_targets.get(source_node_id)
            if target_node_id is None:
                continue
            target_node = self._target_index[target_node_id]

            if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
                edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set))
                edit_script.append(Keep(source_node, target_node))
            else:
                edit_script.append(Update(source_node, target_node))

        return edit_script

    def _generate_move_edits(self, source, target, matching_set):
        """Return Move edits for matched children of ``source`` that did not stay in place.

        Children of ``source`` lying on the longest common subsequence (under the matching)
        with the children of ``target`` are considered in place; every other matched child
        is reported as moved.
        """
        source_args = [id(e) for e in _expression_only_args(source)]
        target_args = [id(e) for e in _expression_only_args(target)]

        args_lcs = set(_lcs(source_args, target_args, lambda l, r: (l, r) in matching_set))

        return [
            Move(self._source_index[a])
            for a in source_args
            if a not in args_lcs and a not in self._unmatched_source_nodes
        ]

    def _compute_matching_set(self):
        """Match leaves first, then inner nodes; return a set of (source_id, target_id) pairs.

        Unmatched nodes are visited in BFS order. A source node is matched with the first
        still-unmatched target node of the same type whose leaf similarity is at least 0.8,
        or at least the (size-adjusted) threshold ``t`` with a bigram Dice coefficient of at
        least ``f``.
        """
        leaves_matching_set = self._compute_leaf_matching_set()
        matching_set = leaves_matching_set.copy()

        # Leaf matching is one-to-one, so a plain map counts common leaves exactly, without
        # scanning every matched leaf pair for each candidate node pair.
        matched_leaf_targets = dict(leaves_matching_set)
        leaf_ids_cache = {}

        def leaf_ids(node):
            # The trees are not mutated during a diff, so leaf sets can be memoized per node.
            node_id = id(node)
            if node_id not in leaf_ids_cache:
                leaf_ids_cache[node_id] = {id(l) for l in _get_leaves(node)}
            return leaf_ids_cache[node_id]

        ordered_unmatched_source_nodes = [
            node_id for node_id in self._source_index if node_id in self._unmatched_source_nodes
        ]
        # A dict rather than a list so matched targets can be dropped in O(1) while the
        # BFS order of the remaining ones is preserved.
        ordered_unmatched_target_nodes = {
            node_id: None
            for node_id in self._target_index
            if node_id in self._unmatched_target_nodes
        }

        for source_node_id in ordered_unmatched_source_nodes:
            source_node = self._source_index[source_node_id]
            for target_node_id in ordered_unmatched_target_nodes:
                target_node = self._target_index[target_node_id]
                if not _is_same_type(source_node, target_node):
                    continue

                source_leaf_ids = leaf_ids(source_node)
                target_leaf_ids = leaf_ids(target_node)

                max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
                if max_leaves_num:
                    common_leaves_num = sum(
                        1
                        for s in source_leaf_ids
                        if matched_leaf_targets.get(s) in target_leaf_ids
                    )
                    leaf_similarity_score = common_leaves_num / max_leaves_num
                else:
                    leaf_similarity_score = 0.0

                # Nodes with few leaves get a relaxed leaf-similarity threshold.
                adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4

                if leaf_similarity_score >= 0.8 or (
                    leaf_similarity_score >= adjusted_t
                    and self._dice_coefficient(source_node, target_node) >= self.f
                ):
                    matching_set.add((source_node_id, target_node_id))
                    self._unmatched_source_nodes.remove(source_node_id)
                    self._unmatched_target_nodes.remove(target_node_id)
                    # Safe despite iterating this dict: we break right after the pop.
                    ordered_unmatched_target_nodes.pop(target_node_id, None)
                    break

        return matching_set

    def _compute_leaf_matching_set(self):
        """Greedily match leaves, highest bigram Dice coefficient first.

        Only same-type leaf pairs scoring at least ``f`` are candidates; equal scores are
        resolved in discovery order. Matched leaves are removed from the unmatched sets.
        """
        candidate_matchings = []
        source_leaves = list(_get_leaves(self._source))
        target_leaves = list(_get_leaves(self._target))
        for source_leaf in source_leaves:
            for target_leaf in target_leaves:
                if _is_same_type(source_leaf, target_leaf):
                    similarity_score = self._dice_coefficient(source_leaf, target_leaf)
                    if similarity_score >= self.f:
                        # The insertion counter breaks score ties, so heapq never has to
                        # compare the expression objects themselves.
                        heappush(
                            candidate_matchings,
                            (
                                -similarity_score,
                                len(candidate_matchings),
                                source_leaf,
                                target_leaf,
                            ),
                        )

        # Pick best matchings based on the highest score
        matching_set = set()
        while candidate_matchings:
            _, _, source_leaf, target_leaf = heappop(candidate_matchings)
            if (
                id(source_leaf) in self._unmatched_source_nodes
                and id(target_leaf) in self._unmatched_target_nodes
            ):
                matching_set.add((id(source_leaf), id(target_leaf)))
                self._unmatched_source_nodes.remove(id(source_leaf))
                self._unmatched_target_nodes.remove(id(target_leaf))

        return matching_set

    def _dice_coefficient(self, source, target):
        """Return the Dice coefficient of the generated-SQL character bigrams of two nodes.

        When neither SQL string has a bigram (both shorter than two characters), the
        result is 1.0 if the nodes are equal and 0.0 otherwise.
        """
        source_histo = self._bigram_histo(source)
        target_histo = self._bigram_histo(target)

        total_grams = sum(source_histo.values()) + sum(target_histo.values())
        if not total_grams:
            return 1.0 if source == target else 0.0

        overlap_len = 0
        overlapping_grams = set(source_histo) & set(target_histo)
        for g in overlapping_grams:
            overlap_len += min(source_histo[g], target_histo[g])

        return 2 * overlap_len / total_grams

    def _bigram_histo(self, expression):
        """Return a bigram -> count histogram of the SQL generated for ``expression`` (cached by id)."""
        if id(expression) in self._bigram_histo_cache:
            return self._bigram_histo_cache[id(expression)]

        expression_str = self._sql_generator.generate(expression)
        count = max(0, len(expression_str) - 1)
        bigram_histo = defaultdict(int)
        for i in range(count):
            bigram_histo[expression_str[i : i + 2]] += 1

        self._bigram_histo_cache[id(expression)] = bigram_histo
        return bigram_histo


def _get_leaves(expression):
    """Yield the nodes under ``expression`` (inclusive) that have no Expression children.

    Traversal is depth-first, following the order of ``expression.args``.
    """
    has_child_exprs = False

    for a in expression.args.values():
        for node in ensure_list(a):
            if isinstance(node, exp.Expression):
                has_child_exprs = True
                yield from _get_leaves(node)

    if not has_child_exprs:
        yield expression


def _is_same_type(source, target):
    """Return True if two nodes may be matched with each other.

    Nodes must have exactly the same class; joins must additionally share the same
    ``side`` and anonymous functions the same ``this``.
    """
    if type(source) is type(target):
        if isinstance(source, exp.Join):
            return source.args.get("side") == target.args.get("side")

        if isinstance(source, exp.Anonymous):
            return source.this == target.this

        return True

    return False


def _expression_only_args(expression):
    """Return the direct Expression children of ``expression``, flattening list-valued args."""
    args = []
    if expression:
        for a in expression.args.values():
            args.extend(ensure_list(a))
    return [a for a in args if isinstance(a, exp.Expression)]


def _lcs(seq_a, seq_b, equal):
    """Calculates the longest common subsequence.

    Elements are compared with ``equal(a_item, b_item)``; the returned items come from
    ``seq_a``. A table of lengths is filled first and the subsequence is then recovered by
    backtracking, which avoids storing a partial subsequence in every table cell.
    """
    len_a = len(seq_a)
    len_b = len(seq_b)
    lengths = [[0] * (len_b + 1) for _ in range(len_a + 1)]

    for i in range(1, len_a + 1):
        for j in range(1, len_b + 1):
            if equal(seq_a[i - 1], seq_b[j - 1]):
                lengths[i][j] = lengths[i - 1][j - 1] + 1
            else:
                lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])

    # Backtrack with the same tie-breaking as the forward pass: take a match diagonally,
    # otherwise move up only when that side is strictly longer, else move left.
    result = []
    i, j = len_a, len_b
    while i > 0 and j > 0:
        if equal(seq_a[i - 1], seq_b[j - 1]):
            result.append(seq_a[i - 1])
            i -= 1
            j -= 1
        elif lengths[i - 1][j] > lengths[i][j - 1]:
            i -= 1
        else:
            j -= 1

    result.reverse()
    return result