summaryrefslogtreecommitdiffstats
path: root/sqlglot/diff.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/diff.py')
-rw-r--r--sqlglot/diff.py55
1 files changed, 33 insertions, 22 deletions
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index 2d959ab..758ad1b 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import typing as t
from collections import defaultdict
from dataclasses import dataclass
from heapq import heappop, heappush
@@ -6,6 +9,10 @@ from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot.helper import ensure_collection
+if t.TYPE_CHECKING:
+ T = t.TypeVar("T")
+ Edit = t.Union[Insert, Remove, Move, Update, Keep]
+
@dataclass(frozen=True)
class Insert:
@@ -44,7 +51,7 @@ class Keep:
target: exp.Expression
-def diff(source, target):
+def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
"""
Returns the list of changes between the source and the target expressions.
@@ -89,25 +96,25 @@ class ChangeDistiller:
Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf.
"""
- def __init__(self, f=0.6, t=0.6):
+ def __init__(self, f: float = 0.6, t: float = 0.6) -> None:
self.f = f
self.t = t
self._sql_generator = Dialect().generator()
- def diff(self, source, target):
+ def diff(self, source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
self._source = source
self._target = target
self._source_index = {id(n[0]): n[0] for n in source.bfs()}
self._target_index = {id(n[0]): n[0] for n in target.bfs()}
self._unmatched_source_nodes = set(self._source_index)
self._unmatched_target_nodes = set(self._target_index)
- self._bigram_histo_cache = {}
+ self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
matching_set = self._compute_matching_set()
return self._generate_edit_script(matching_set)
- def _generate_edit_script(self, matching_set):
- edit_script = []
+ def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]:
+ edit_script: t.List[Edit] = []
for removed_node_id in self._unmatched_source_nodes:
edit_script.append(Remove(self._source_index[removed_node_id]))
for inserted_node_id in self._unmatched_target_nodes:
@@ -125,7 +132,9 @@ class ChangeDistiller:
return edit_script
- def _generate_move_edits(self, source, target, matching_set):
+ def _generate_move_edits(
+ self, source: exp.Expression, target: exp.Expression, matching_set: t.Set[t.Tuple[int, int]]
+ ) -> t.List[Move]:
source_args = [id(e) for e in _expression_only_args(source)]
target_args = [id(e) for e in _expression_only_args(target)]
@@ -138,7 +147,7 @@ class ChangeDistiller:
return move_edits
- def _compute_matching_set(self):
+ def _compute_matching_set(self) -> t.Set[t.Tuple[int, int]]:
leaves_matching_set = self._compute_leaf_matching_set()
matching_set = leaves_matching_set.copy()
@@ -183,8 +192,8 @@ class ChangeDistiller:
return matching_set
- def _compute_leaf_matching_set(self):
- candidate_matchings = []
+ def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]:
+ candidate_matchings: t.List[t.Tuple[float, int, exp.Expression, exp.Expression]] = []
source_leaves = list(_get_leaves(self._source))
target_leaves = list(_get_leaves(self._target))
for source_leaf in source_leaves:
@@ -216,7 +225,7 @@ class ChangeDistiller:
return matching_set
- def _dice_coefficient(self, source, target):
+ def _dice_coefficient(self, source: exp.Expression, target: exp.Expression) -> float:
source_histo = self._bigram_histo(source)
target_histo = self._bigram_histo(target)
@@ -231,13 +240,13 @@ class ChangeDistiller:
return 2 * overlap_len / total_grams
- def _bigram_histo(self, expression):
+ def _bigram_histo(self, expression: exp.Expression) -> t.DefaultDict[str, int]:
if id(expression) in self._bigram_histo_cache:
return self._bigram_histo_cache[id(expression)]
expression_str = self._sql_generator.generate(expression)
count = max(0, len(expression_str) - 1)
- bigram_histo = defaultdict(int)
+ bigram_histo: t.DefaultDict[str, int] = defaultdict(int)
for i in range(count):
bigram_histo[expression_str[i : i + 2]] += 1
@@ -245,7 +254,7 @@ class ChangeDistiller:
return bigram_histo
-def _get_leaves(expression):
+def _get_leaves(expression: exp.Expression) -> t.Generator[exp.Expression, None, None]:
has_child_exprs = False
for a in expression.args.values():
@@ -258,7 +267,7 @@ def _get_leaves(expression):
yield expression
-def _is_same_type(source, target):
+def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
if type(source) is type(target):
if isinstance(source, exp.Join):
return source.args.get("side") == target.args.get("side")
@@ -271,15 +280,17 @@ def _is_same_type(source, target):
return False
-def _expression_only_args(expression):
- args = []
+def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
+ args: t.List[t.Union[exp.Expression, t.List]] = []
if expression:
for a in expression.args.values():
args.extend(ensure_collection(a))
return [a for a in args if isinstance(a, exp.Expression)]
-def _lcs(seq_a, seq_b, equal):
+def _lcs(
+ seq_a: t.Sequence[T], seq_b: t.Sequence[T], equal: t.Callable[[T, T], bool]
+) -> t.Sequence[t.Optional[T]]:
"""Calculates the longest common subsequence"""
len_a = len(seq_a)
@@ -289,14 +300,14 @@ def _lcs(seq_a, seq_b, equal):
for i in range(len_a + 1):
for j in range(len_b + 1):
if i == 0 or j == 0:
- lcs_result[i][j] = []
+ lcs_result[i][j] = [] # type: ignore
elif equal(seq_a[i - 1], seq_b[j - 1]):
- lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]]
+ lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] # type: ignore
else:
lcs_result[i][j] = (
lcs_result[i - 1][j]
- if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1])
+ if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) # type: ignore
else lcs_result[i][j - 1]
)
- return lcs_result[len_a][len_b]
+ return lcs_result[len_a][len_b] # type: ignore