"""Tests for ``sqlglot.diff`` run in delta-only mode."""

import unittest

from sqlglot import exp, parse_one
from sqlglot.diff import Insert, Move, Remove, Update, diff


def diff_delta_only(source, target, matchings=None, **kwargs):
    """Call ``diff`` with ``delta_only=True``.

    Args:
        source: Expression to diff from.
        target: Expression to diff to.
        matchings: Optional pre-matched (source, target) node pairs, forwarded as-is.
        **kwargs: Any further keyword arguments, forwarded to ``diff`` unchanged.

    Returns:
        Whatever ``diff`` returns for these arguments.
    """
    return diff(source, target, delta_only=True, matchings=matchings, **kwargs)


class TestDiff(unittest.TestCase):
    """Checks the edit scripts produced by ``diff_delta_only`` for assorted SQL changes."""

    def test_simple(self):
        """Operator swap, projection removal/insertion and a table rename."""
        delta = diff_delta_only(parse_one("SELECT a + b"), parse_one("SELECT a - b"))
        self._validate_delta_only(
            delta,
            [
                Remove(expression=parse_one("a + b")),  # the Add node
                Insert(expression=parse_one("a - b")),  # the Sub node
                Move(source=parse_one("a"), target=parse_one("a")),  # the `a` Column node
                Move(source=parse_one("b"), target=parse_one("b")),  # the `b` Column node
            ],
        )

        delta = diff_delta_only(parse_one("SELECT a, b, c"), parse_one("SELECT a, c"))
        self._validate_delta_only(
            delta,
            [
                Remove(expression=parse_one("b")),  # the Column node
            ],
        )

        delta = diff_delta_only(parse_one("SELECT a, b"), parse_one("SELECT a, b, c"))
        self._validate_delta_only(
            delta,
            [
                Insert(expression=parse_one("c")),  # the Column node
            ],
        )

        delta = diff_delta_only(
            parse_one("SELECT a FROM table_one"),
            parse_one("SELECT a FROM table_two"),
        )
        self._validate_delta_only(
            delta,
            [
                Update(
                    source=exp.to_table("table_one", quoted=False),
                    target=exp.to_table("table_two", quoted=False),
                ),  # the Table node
            ],
        )

    def test_lambda(self):
        """Renaming a lambda's parameter is expected to be a single Update of the Lambda."""
        delta = diff_delta_only(
            parse_one("SELECT a, b, c, x(a -> a)"),
            parse_one("SELECT a, b, c, x(b -> b)"),
        )
        lambda_before = exp.Lambda(
            this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")]
        )
        lambda_after = exp.Lambda(
            this=exp.to_identifier("b"), expressions=[exp.to_identifier("b")]
        )
        self._validate_delta_only(
            delta,
            [
                Update(source=lambda_before, target=lambda_after),
            ],
        )

    def test_udf(self):
        """Changes to a quoted UDF's name, or to one of its arguments."""
        delta = diff_delta_only(
            parse_one('SELECT a, b, "my.udf1"()'),
            parse_one('SELECT a, b, "my.udf2"()'),
        )
        self._validate_delta_only(
            delta,
            [
                Insert(expression=parse_one('"my.udf2"()')),
                Remove(expression=parse_one('"my.udf1"()')),
            ],
        )

        delta = diff_delta_only(
            parse_one('SELECT a, b, "my.udf"(x, y, z)'),
            parse_one('SELECT a, b, "my.udf"(x, y, w)'),
        )
        self._validate_delta_only(
            delta,
            [
                Insert(expression=exp.column("w")),
                Remove(expression=exp.column("z")),
            ],
        )

    def test_node_position_changed(self):
        """Nodes that only change position are expected to be reported as Move edits."""
        source = parse_one("SELECT a, b, c")
        target = parse_one("SELECT c, a, b")
        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [
                Move(source=source.selects[2], target=target.selects[0]),
            ],
        )

        source = parse_one("SELECT a + b")
        target = parse_one("SELECT b + a")
        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [
                Move(source=source.selects[0].left, target=target.selects[0].right),
            ],
        )

        source = parse_one("SELECT aaaa AND bbbb")
        target = parse_one("SELECT bbbb AND aaaa")
        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [
                Move(source=source.selects[0].left, target=target.selects[0].right),
            ],
        )

        source = parse_one("SELECT aaaa OR bbbb OR cccc")
        target = parse_one("SELECT cccc OR bbbb OR aaaa")
        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [
                Move(source=source.selects[0].left.left, target=target.selects[0].right),
                Move(source=source.selects[0].right, target=target.selects[0].left.left),
            ],
        )

        source = parse_one("SELECT a, b FROM t WHERE CONCAT('a', 'b') = 'ab'")
        target = parse_one("SELECT a FROM t WHERE CONCAT('a', 'b', b) = 'ab'")
        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [
                Move(source=source.selects[1], target=target.find(exp.Concat).expressions[-1]),
            ],
        )

        source = parse_one("SELECT a as a, b as b FROM t WHERE CONCAT('a', 'b') = 'ab'")
        target = parse_one("SELECT a as a FROM t WHERE CONCAT('a', 'b', b) = 'ab'")
        b_alias = source.selects[1]
        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [
                Remove(expression=b_alias),
                Move(source=b_alias.this, target=target.find(exp.Concat).expressions[-1]),
            ],
        )

    def test_cte(self):
        """Edits made inside a CTE body while the outer query stays the same."""
        sql_before = """
            WITH
            cte1 AS (SELECT a, b, LOWER(c) AS c FROM table_one WHERE d = 'filter'),
            cte2 AS (SELECT d, e, f FROM table_two)
            SELECT a, b, d, e FROM cte1 JOIN cte2 ON f = c
        """
        sql_after = """
            WITH
            cte1 AS (SELECT a, b, c FROM table_one WHERE d = 'different_filter'),
            cte2 AS (SELECT d, e, f FROM table_two)
            SELECT a, b, d, e FROM cte1 JOIN cte2 ON f = c
        """

        delta = diff_delta_only(parse_one(sql_before), parse_one(sql_after))
        self._validate_delta_only(
            delta,
            [
                Remove(expression=parse_one("LOWER(c) AS c")),  # the Alias node
                Remove(expression=parse_one("LOWER(c)")),  # the Lower node
                Remove(expression=parse_one("'filter'")),  # the Literal node
                Insert(expression=parse_one("'different_filter'")),  # the Literal node
                Move(source=parse_one("c"), target=parse_one("c")),  # the new Column c
            ],
        )

    def test_join(self):
        """Switching LEFT JOIN to RIGHT JOIN: Join replaced, its table and ON clause moved."""
        source = parse_one("SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key")
        target = parse_one("SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key")

        src_join = source.find(exp.Join)
        tgt_join = target.find(exp.Join)

        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [
                Remove(expression=src_join),
                Insert(expression=tgt_join),
                Move(source=exp.to_table("t2"), target=exp.to_table("t2")),
                Move(source=src_join.args["on"], target=tgt_join.args["on"]),
            ],
        )

    def test_window_functions(self):
        """Changes to the function inside a window, and OVER versus Oracle's KEEP."""
        source = parse_one("SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b)")
        target = parse_one("SELECT RANK() OVER (PARTITION BY a ORDER BY b)")

        # Diffing an expression against itself must produce an empty delta.
        self._validate_delta_only(diff_delta_only(source, source), [])

        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [
                Remove(expression=parse_one("ROW_NUMBER()")),
                Insert(expression=parse_one("RANK()")),
                Update(source=source.selects[0], target=target.selects[0]),
            ],
        )

        source = parse_one("SELECT MAX(x) OVER (ORDER BY y) FROM z", "oracle")
        target = parse_one("SELECT MAX(x) KEEP (DENSE_RANK LAST ORDER BY y) FROM z", "oracle")
        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [Update(source=source.selects[0], target=target.selects[0])],
        )

    def test_pre_matchings(self):
        """Compare the delta with and without caller-supplied matchings."""
        source = parse_one("SELECT 1")
        target = parse_one("SELECT 1, 2, 3, 4")

        # No matchings supplied: the root Select is expected to be removed and re-inserted.
        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [
                Remove(expression=source),
                Insert(expression=target),
                Insert(expression=exp.Literal.number(2)),
                Insert(expression=exp.Literal.number(3)),
                Insert(expression=exp.Literal.number(4)),
                Move(source=exp.Literal.number(1), target=exp.Literal.number(1)),
            ],
        )

        # Roots pre-matched: only the three new literals are expected.
        delta = diff_delta_only(source, target, matchings=[(source, target)])
        self._validate_delta_only(
            delta,
            [
                Insert(expression=exp.Literal.number(2)),
                Insert(expression=exp.Literal.number(3)),
                Insert(expression=exp.Literal.number(4)),
            ],
        )

        # Supplying the same pair twice is expected to raise ValueError.
        with self.assertRaises(ValueError):
            diff_delta_only(
                source, target, matchings=[(source, target), (source, target)]
            )

    def test_identifier(self):
        """Deltas involving a qualified column and aliased literals."""
        source = parse_one("SELECT a FROM tbl")
        target = parse_one("SELECT a, tbl.b from tbl")
        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [
                Insert(expression=exp.to_column("tbl.b")),
            ],
        )

        source = parse_one("SELECT 1 AS c1, 2 AS c2")
        target = parse_one("SELECT 2 AS c1, 3 AS c2")
        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [
                Remove(expression=exp.alias_(1, "c1")),
                Remove(expression=exp.Literal.number(1)),
                Insert(expression=exp.alias_(3, "c2")),
                Insert(expression=exp.Literal.number(3)),
                Update(source=exp.alias_(2, "c2"), target=exp.alias_(2, "c1")),
            ],
        )

    def test_dialect_aware_diff(self):
        """Diffing an Oracle expression against its copy gives no edits and no extra warnings."""
        from sqlglot.generator import logger

        with self.assertLogs(logger) as cm:
            # 'assertLogs' cannot assert that nothing was logged, so a dummy warning is
            # emitted up front and later checked to be the only captured record.
            logger.warning("Dummy warning")

            expression = parse_one("SELECT foo FROM bar FOR UPDATE", dialect="oracle")
            delta = diff_delta_only(expression, expression.copy(), dialect="oracle")
            self._validate_delta_only(delta, [])

        self.assertEqual(["WARNING:sqlglot:Dummy warning"], cm.output)

    def test_non_expression_leaf_delta(self):
        """UNION vs UNION ALL and ASC vs DESC are expected to show up as Update edits."""
        source = parse_one("SELECT a UNION SELECT b")
        target = parse_one("SELECT a UNION ALL SELECT b")
        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [
                Update(source=source, target=target),
            ],
        )

        source = parse_one("SELECT a FROM t ORDER BY b ASC")
        target = parse_one("SELECT a FROM t ORDER BY b DESC")
        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [
                Update(
                    source=source.find(exp.Order).expressions[0],
                    target=target.find(exp.Order).expressions[0],
                ),
            ],
        )

        source = parse_one("SELECT a, b FROM t ORDER BY c ASC")
        target = parse_one("SELECT b, a FROM t ORDER BY c DESC")
        delta = diff_delta_only(source, target)
        self._validate_delta_only(
            delta,
            [
                Update(
                    source=source.find(exp.Order).expressions[0],
                    target=target.find(exp.Order).expressions[0],
                ),
                Move(source=source.selects[0], target=target.selects[1]),
            ],
        )

    def _validate_delta_only(self, actual_delta, expected_delta):
        """Assert both deltas hold the same edits, ignoring order and duplicates."""
        actual_edits = set(actual_delta)
        expected_edits = set(expected_delta)
        self.assertEqual(actual_edits, expected_edits)