summaryrefslogtreecommitdiffstats
path: root/sqlglot/diff.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/diff.py')
-rw-r--r--sqlglot/diff.py314
1 files changed, 314 insertions, 0 deletions
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
new file mode 100644
index 0000000..8eeb4e9
--- /dev/null
+++ b/sqlglot/diff.py
@@ -0,0 +1,314 @@
+from collections import defaultdict
+from dataclasses import dataclass
+from heapq import heappop, heappush
+
+from sqlglot import Dialect
+from sqlglot import expressions as exp
+from sqlglot.helper import ensure_list
+
+
@dataclass(frozen=True)
class Insert:
    """Edit-script entry recording that a node was inserted."""

    # The node that was inserted.
    expression: exp.Expression
+
+
@dataclass(frozen=True)
class Remove:
    """Edit-script entry recording that an existing node was removed."""

    # The node that was removed.
    expression: exp.Expression
+
+
@dataclass(frozen=True)
class Move:
    """Edit-script entry recording that an existing node changed position within the tree."""

    # The node whose position changed.
    expression: exp.Expression
+
+
@dataclass(frozen=True)
class Update:
    """Edit-script entry recording that an existing node was updated."""

    # The node as it appears on the source side.
    source: exp.Expression
    # The node as it appears on the target side.
    target: exp.Expression
+
+
@dataclass(frozen=True)
class Keep:
    """Edit-script entry recording that an existing node was left unchanged."""

    # The node as it appears on the source side.
    source: exp.Expression
    # The matching node on the target side.
    target: exp.Expression
+
+
def diff(source, target):
    """
    Computes the changes that turn the source expression into the target expression.

    Examples:
        >>> diff(parse_one("a + b"), parse_one("a + c"))
        [
            Remove(expression=(COLUMN this: (IDENTIFIER this: b, quoted: False))),
            Insert(expression=(COLUMN this: (IDENTIFIER this: c, quoted: False))),
            Keep(
                source=(ADD this: ...),
                target=(ADD this: ...)
            ),
            Keep(
                source=(COLUMN this: (IDENTIFIER this: a, quoted: False)),
                target=(COLUMN this: (IDENTIFIER this: a, quoted: False))
            ),
        ]

    Args:
        source (sqlglot.Expression): the expression to start from.
        target (sqlglot.Expression): the expression the diff is calculated against.

    Returns:
        a list of Insert, Remove, Move, Update and Keep objects, one for each node in the source
        and the target expression trees. Taken together they describe the steps that transform
        the source expression tree into the target one.
    """
    distiller = ChangeDistiller()
    # The distiller is handed copies of both trees rather than the caller's objects.
    source_copy = source.copy()
    target_copy = target.copy()
    return distiller.diff(source_copy, target_copy)
+
+
# Expression types for which a matched pair that is not equal is reported as an
# Update instead of a Keep when the edit script is generated.
LEAF_EXPRESSION_TYPES = (exp.Boolean, exp.DataType, exp.Identifier, exp.Literal)
+
+
class ChangeDistiller:
    """
    The implementation of the Change Distiller algorithm described by Beat Fluri and Martin Pinzger in
    their paper https://ieeexplore.ieee.org/document/4339230, which in turn is based on the algorithm by
    Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf.

    Args:
        f (float): the minimum bigram Dice coefficient two nodes must reach to be
            considered similar.
        t (float): the minimum ratio of matched leaves required to match two inner
            nodes when the smaller of the two has more than 4 leaves (0.4 is used
            otherwise).
    """

    def __init__(self, f=0.6, t=0.6):
        self.f = f
        self.t = t
        self._sql_generator = Dialect().generator()

    def diff(self, source, target):
        """
        Computes the edit script between two expression trees.

        Args:
            source: the source expression tree.
            target: the target expression tree.

        Returns:
            the list of Remove, Insert, Move, Keep and Update edits.
        """
        self._source = source
        self._target = target
        # Nodes are tracked by id(); both indexes preserve the BFS visiting order.
        self._source_index = {id(n[0]): n[0] for n in source.bfs()}
        self._target_index = {id(n[0]): n[0] for n in target.bfs()}
        self._unmatched_source_nodes = set(self._source_index)
        self._unmatched_target_nodes = set(self._target_index)
        self._bigram_histo_cache = {}

        matching_set = self._compute_matching_set()
        return self._generate_edit_script(matching_set)

    def _generate_edit_script(self, matching_set):
        """
        Turns the matching set and the leftover unmatched nodes into a list of edits.

        Edits are emitted in the BFS order recorded by the indexes rather than in
        set iteration order, which depends on object addresses and therefore
        differs between runs.
        """
        edit_script = []

        for node_id, node in self._source_index.items():
            if node_id in self._unmatched_source_nodes:
                edit_script.append(Remove(node))

        for node_id, node in self._target_index.items():
            if node_id in self._unmatched_target_nodes:
                edit_script.append(Insert(node))

        source_position = {
            node_id: position for position, node_id in enumerate(self._source_index)
        }
        target_position = {
            node_id: position for position, node_id in enumerate(self._target_index)
        }
        ordered_matchings = sorted(
            matching_set,
            key=lambda pair: (source_position[pair[0]], target_position[pair[1]]),
        )

        for kept_source_node_id, kept_target_node_id in ordered_matchings:
            source_node = self._source_index[kept_source_node_id]
            target_node = self._target_index[kept_target_node_id]
            if (
                not isinstance(source_node, LEAF_EXPRESSION_TYPES)
                or source_node == target_node
            ):
                edit_script.extend(
                    self._generate_move_edits(source_node, target_node, matching_set)
                )
                edit_script.append(Keep(source_node, target_node))
            else:
                # Matched nodes of a leaf type that are not equal.
                edit_script.append(Update(source_node, target_node))

        return edit_script

    def _generate_move_edits(self, source, target, matching_set):
        """
        Returns a Move edit for every matched child of `source` that is not part of the
        longest common subsequence between the children of `source` and `target`.
        """
        source_args = [id(e) for e in _expression_only_args(source)]
        target_args = [id(e) for e in _expression_only_args(target)]

        # Two children are "equal" when they have been matched with each other.
        args_lcs = set(
            _lcs(source_args, target_args, lambda l, r: (l, r) in matching_set)
        )

        return [
            Move(self._source_index[a])
            for a in source_args
            if a not in args_lcs and a not in self._unmatched_source_nodes
        ]

    def _compute_matching_set(self):
        """
        Matches leaves first and then the remaining nodes, based on the share of
        matched leaves and on the bigram similarity of the nodes.

        Returns:
            a set of (source node id, target node id) pairs.
        """
        leaves_matching_set = self._compute_leaf_matching_set()
        matching_set = leaves_matching_set.copy()

        # The indexes were populated in BFS order, so filtering them preserves that order.
        ordered_unmatched_source_nodes = [
            node_id
            for node_id in self._source_index
            if node_id in self._unmatched_source_nodes
        ]
        ordered_unmatched_target_nodes = {
            node_id: None
            for node_id in self._target_index
            if node_id in self._unmatched_target_nodes
        }

        # Leaf sets do not change while matching, so compute them at most once per node.
        leaf_ids_cache = {}

        def leaf_ids(node):
            node_id = id(node)
            if node_id not in leaf_ids_cache:
                leaf_ids_cache[node_id] = {id(leaf) for leaf in _get_leaves(node)}
            return leaf_ids_cache[node_id]

        for source_node_id in ordered_unmatched_source_nodes:
            source_node = self._source_index[source_node_id]

            for target_node_id in ordered_unmatched_target_nodes:
                target_node = self._target_index[target_node_id]
                if not _is_same_type(source_node, target_node):
                    continue

                source_leaf_ids = leaf_ids(source_node)
                target_leaf_ids = leaf_ids(target_node)

                max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
                if max_leaves_num:
                    common_leaves_num = sum(
                        1
                        for s, t in leaves_matching_set
                        if s in source_leaf_ids and t in target_leaf_ids
                    )
                    leaf_similarity_score = common_leaves_num / max_leaves_num
                else:
                    leaf_similarity_score = 0.0

                # Subtrees with few leaves are matched against a lower threshold.
                adjusted_t = (
                    self.t
                    if min(len(source_leaf_ids), len(target_leaf_ids)) > 4
                    else 0.4
                )

                if leaf_similarity_score >= 0.8 or (
                    leaf_similarity_score >= adjusted_t
                    and self._dice_coefficient(source_node, target_node) >= self.f
                ):
                    matching_set.add((source_node_id, target_node_id))
                    self._unmatched_source_nodes.remove(source_node_id)
                    self._unmatched_target_nodes.remove(target_node_id)
                    # Safe although we are iterating this dict: the loop is left right away.
                    del ordered_unmatched_target_nodes[target_node_id]
                    break

        return matching_set

    def _compute_leaf_matching_set(self):
        """
        Pairs up source and target leaves of the same type whose Dice coefficient is
        at least `f`, preferring the pairs with the highest score.

        Returns:
            a set of (source leaf id, target leaf id) pairs.
        """
        candidate_matchings = []
        source_leaves = list(_get_leaves(self._source))
        target_leaves = list(_get_leaves(self._target))
        for source_leaf in source_leaves:
            for target_leaf in target_leaves:
                if _is_same_type(source_leaf, target_leaf):
                    similarity_score = self._dice_coefficient(source_leaf, target_leaf)
                    if similarity_score >= self.f:
                        heappush(
                            candidate_matchings,
                            (
                                # Negated so that the min-heap pops the best score first.
                                -similarity_score,
                                # Unique tie-breaker: the leaves themselves are never compared.
                                len(candidate_matchings),
                                source_leaf,
                                target_leaf,
                            ),
                        )

        # Pick best matchings based on the highest score
        matching_set = set()
        while candidate_matchings:
            _, _, source_leaf, target_leaf = heappop(candidate_matchings)
            if (
                id(source_leaf) in self._unmatched_source_nodes
                and id(target_leaf) in self._unmatched_target_nodes
            ):
                matching_set.add((id(source_leaf), id(target_leaf)))
                self._unmatched_source_nodes.remove(id(source_leaf))
                self._unmatched_target_nodes.remove(id(target_leaf))

        return matching_set

    def _dice_coefficient(self, source, target):
        """
        Returns the Dice coefficient of the bigrams of the generated strings of the two
        expressions. When neither string has any bigram, the result is 1.0 if the
        expressions are equal and 0.0 otherwise.
        """
        source_histo = self._bigram_histo(source)
        target_histo = self._bigram_histo(target)

        total_grams = sum(source_histo.values()) + sum(target_histo.values())
        if not total_grams:
            return 1.0 if source == target else 0.0

        overlap_len = sum(
            min(source_histo[g], target_histo[g])
            for g in set(source_histo) & set(target_histo)
        )

        return 2 * overlap_len / total_grams

    def _bigram_histo(self, expression):
        """
        Returns the bigram -> number of occurrences histogram of the string generated
        for the expression. Results are cached by the id() of the expression.
        """
        if id(expression) in self._bigram_histo_cache:
            return self._bigram_histo_cache[id(expression)]

        expression_str = self._sql_generator.generate(expression)
        count = max(0, len(expression_str) - 1)
        bigram_histo = defaultdict(int)
        for i in range(count):
            bigram_histo[expression_str[i : i + 2]] += 1

        self._bigram_histo_cache[id(expression)] = bigram_histo
        return bigram_histo
+
+
+def _get_leaves(expression):
+ has_child_exprs = False
+
+ for a in expression.args.values():
+ nodes = ensure_list(a)
+ for node in nodes:
+ if isinstance(node, exp.Expression):
+ has_child_exprs = True
+ yield from _get_leaves(node)
+
+ if not has_child_exprs:
+ yield expression
+
+
+def _is_same_type(source, target):
+ if type(source) is type(target):
+ if isinstance(source, exp.Join):
+ return source.args.get("side") == target.args.get("side")
+
+ if isinstance(source, exp.Anonymous):
+ return source.this == target.this
+
+ return True
+
+ return False
+
+
+def _expression_only_args(expression):
+ args = []
+ if expression:
+ for a in expression.args.values():
+ args.extend(ensure_list(a))
+ return [a for a in args if isinstance(a, exp.Expression)]
+
+
+def _lcs(seq_a, seq_b, equal):
+ """Calculates the longest common subsequence"""
+
+ len_a = len(seq_a)
+ len_b = len(seq_b)
+ lcs_result = [[None] * (len_b + 1) for i in range(len_a + 1)]
+
+ for i in range(len_a + 1):
+ for j in range(len_b + 1):
+ if i == 0 or j == 0:
+ lcs_result[i][j] = []
+ elif equal(seq_a[i - 1], seq_b[j - 1]):
+ lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]]
+ else:
+ lcs_result[i][j] = (
+ lcs_result[i - 1][j]
+ if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1])
+ else lcs_result[i][j - 1]
+ )
+
+ return lcs_result[len_a][len_b]