summaryrefslogtreecommitdiffstats
path: root/sqlglot/diff.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/diff.py')
-rw-r--r--sqlglot/diff.py23
1 files changed, 15 insertions, 8 deletions
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index 0567c12..2d959ab 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -4,7 +4,7 @@ from heapq import heappop, heappush
from sqlglot import Dialect
from sqlglot import expressions as exp
-from sqlglot.helper import ensure_list
+from sqlglot.helper import ensure_collection
@dataclass(frozen=True)
@@ -116,7 +116,9 @@ class ChangeDistiller:
source_node = self._source_index[kept_source_node_id]
target_node = self._target_index[kept_target_node_id]
if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
- edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set))
+ edit_script.extend(
+ self._generate_move_edits(source_node, target_node, matching_set)
+ )
edit_script.append(Keep(source_node, target_node))
else:
edit_script.append(Update(source_node, target_node))
@@ -158,13 +160,16 @@ class ChangeDistiller:
max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
if max_leaves_num:
common_leaves_num = sum(
- 1 if s in source_leaf_ids and t in target_leaf_ids else 0 for s, t in leaves_matching_set
+ 1 if s in source_leaf_ids and t in target_leaf_ids else 0
+ for s, t in leaves_matching_set
)
leaf_similarity_score = common_leaves_num / max_leaves_num
else:
leaf_similarity_score = 0.0
- adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
+ adjusted_t = (
+ self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
+ )
if leaf_similarity_score >= 0.8 or (
leaf_similarity_score >= adjusted_t
@@ -201,7 +206,10 @@ class ChangeDistiller:
matching_set = set()
while candidate_matchings:
_, _, source_leaf, target_leaf = heappop(candidate_matchings)
- if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes:
+ if (
+ id(source_leaf) in self._unmatched_source_nodes
+ and id(target_leaf) in self._unmatched_target_nodes
+ ):
matching_set.add((id(source_leaf), id(target_leaf)))
self._unmatched_source_nodes.remove(id(source_leaf))
self._unmatched_target_nodes.remove(id(target_leaf))
@@ -241,8 +249,7 @@ def _get_leaves(expression):
has_child_exprs = False
for a in expression.args.values():
- nodes = ensure_list(a)
- for node in nodes:
+ for node in ensure_collection(a):
if isinstance(node, exp.Expression):
has_child_exprs = True
yield from _get_leaves(node)
@@ -268,7 +275,7 @@ def _expression_only_args(expression):
args = []
if expression:
for a in expression.args.values():
- args.extend(ensure_list(a))
+ args.extend(ensure_collection(a))
return [a for a in args if isinstance(a, exp.Expression)]