diff options
Diffstat (limited to 'sqlglot/diff.py')
-rw-r--r-- | sqlglot/diff.py | 55 |
1 files changed, 33 insertions, 22 deletions
diff --git a/sqlglot/diff.py b/sqlglot/diff.py index 2d959ab..758ad1b 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import typing as t from collections import defaultdict from dataclasses import dataclass from heapq import heappop, heappush @@ -6,6 +9,10 @@ from sqlglot import Dialect from sqlglot import expressions as exp from sqlglot.helper import ensure_collection +if t.TYPE_CHECKING: + T = t.TypeVar("T") + Edit = t.Union[Insert, Remove, Move, Update, Keep] + @dataclass(frozen=True) class Insert: @@ -44,7 +51,7 @@ class Keep: target: exp.Expression -def diff(source, target): +def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]: """ Returns the list of changes between the source and the target expressions. @@ -89,25 +96,25 @@ class ChangeDistiller: Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf. """ - def __init__(self, f=0.6, t=0.6): + def __init__(self, f: float = 0.6, t: float = 0.6) -> None: self.f = f self.t = t self._sql_generator = Dialect().generator() - def diff(self, source, target): + def diff(self, source: exp.Expression, target: exp.Expression) -> t.List[Edit]: self._source = source self._target = target self._source_index = {id(n[0]): n[0] for n in source.bfs()} self._target_index = {id(n[0]): n[0] for n in target.bfs()} self._unmatched_source_nodes = set(self._source_index) self._unmatched_target_nodes = set(self._target_index) - self._bigram_histo_cache = {} + self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {} matching_set = self._compute_matching_set() return self._generate_edit_script(matching_set) - def _generate_edit_script(self, matching_set): - edit_script = [] + def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]: + edit_script: t.List[Edit] = [] for removed_node_id in self._unmatched_source_nodes: edit_script.append(Remove(self._source_index[removed_node_id])) for inserted_node_id in self._unmatched_target_nodes: @@ -125,7 +132,9 @@ class ChangeDistiller: return edit_script - def _generate_move_edits(self, source, target, matching_set): + def _generate_move_edits( + self, source: exp.Expression, target: exp.Expression, matching_set: t.Set[t.Tuple[int, int]] + ) -> t.List[Move]: source_args = [id(e) for e in _expression_only_args(source)] target_args = [id(e) for e in _expression_only_args(target)] @@ -138,7 +147,7 @@ class ChangeDistiller: return move_edits - def _compute_matching_set(self): + def _compute_matching_set(self) -> t.Set[t.Tuple[int, int]]: leaves_matching_set = self._compute_leaf_matching_set() matching_set = leaves_matching_set.copy() @@ -183,8 +192,8 @@ class ChangeDistiller: return matching_set - def _compute_leaf_matching_set(self): - candidate_matchings = [] + def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]: + candidate_matchings: t.List[t.Tuple[float, int, exp.Expression, exp.Expression]] = [] source_leaves = list(_get_leaves(self._source)) target_leaves = list(_get_leaves(self._target)) for source_leaf in source_leaves: @@ -216,7 +225,7 @@ class ChangeDistiller: return matching_set - def _dice_coefficient(self, source, target): + def _dice_coefficient(self, source: exp.Expression, target: exp.Expression) -> float: source_histo = self._bigram_histo(source) target_histo = self._bigram_histo(target) @@ -231,13 +240,13 @@ class ChangeDistiller: return 2 * overlap_len / total_grams - def _bigram_histo(self, expression): + def _bigram_histo(self, expression: exp.Expression) -> t.DefaultDict[str, int]: if id(expression) in self._bigram_histo_cache: return self._bigram_histo_cache[id(expression)] expression_str = self._sql_generator.generate(expression) count = max(0, len(expression_str) - 1) - bigram_histo = defaultdict(int) + bigram_histo: t.DefaultDict[str, int] = defaultdict(int) for i in range(count): bigram_histo[expression_str[i : i + 2]] += 1 @@ -245,7 +254,7 @@ class ChangeDistiller: return bigram_histo -def _get_leaves(expression): +def _get_leaves(expression: exp.Expression) -> t.Generator[exp.Expression, None, None]: has_child_exprs = False for a in expression.args.values(): @@ -258,7 +267,7 @@ def _get_leaves(expression): yield expression -def _is_same_type(source, target): +def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool: if type(source) is type(target): if isinstance(source, exp.Join): return source.args.get("side") == target.args.get("side") @@ -271,15 +280,17 @@ def _is_same_type(source, target): return False -def _expression_only_args(expression): - args = [] +def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]: + args: t.List[t.Union[exp.Expression, t.List]] = [] if expression: for a in expression.args.values(): args.extend(ensure_collection(a)) return [a for a in args if isinstance(a, exp.Expression)] -def _lcs(seq_a, seq_b, equal): +def _lcs( + seq_a: t.Sequence[T], seq_b: t.Sequence[T], equal: t.Callable[[T, T], bool] +) -> t.Sequence[t.Optional[T]]: """Calculates the longest common subsequence""" len_a = len(seq_a) @@ -289,14 +300,14 @@ def _lcs(seq_a, seq_b, equal): for i in range(len_a + 1): for j in range(len_b + 1): if i == 0 or j == 0: - lcs_result[i][j] = [] + lcs_result[i][j] = [] # type: ignore elif equal(seq_a[i - 1], seq_b[j - 1]): - lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] + lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] # type: ignore else: lcs_result[i][j] = ( lcs_result[i - 1][j] - if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) + if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) # type: ignore else lcs_result[i][j - 1] ) - return lcs_result[len_a][len_b] + return lcs_result[len_a][len_b] # type: ignore |