import unittest

from sqlglot import exp, parse_one
from sqlglot.diff import Insert, Keep, Move, Remove, Update, diff
from sqlglot.expressions import Join, to_table


class TestDiff(unittest.TestCase):
    """Tests for ``sqlglot.diff.diff``.

    Most tests look only at the "delta" of a diff: every edit except ``Keep``.
    They compare it against an expected collection of ``Insert`` / ``Remove`` /
    ``Update`` / ``Move`` edits.
    """

    def test_simple(self):
        # Changing the operator replaces the whole binary node.
        self._validate_delta_only(
            diff(parse_one("SELECT a + b"), parse_one("SELECT a - b")),
            [
                Remove(parse_one("a + b")),  # the Add node
                Insert(parse_one("a - b")),  # the Sub node
            ],
        )

        # Dropping a projection removes just that column.
        self._validate_delta_only(
            diff(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")),
            [
                Remove(parse_one("b")),  # the Column node
            ],
        )

        # Adding a projection inserts just that column.
        self._validate_delta_only(
            diff(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")),
            [
                Insert(parse_one("c")),  # the Column node
            ],
        )

        # Renaming the source table is reported as an in-place Update.
        self._validate_delta_only(
            diff(
                parse_one("SELECT a FROM table_one"),
                parse_one("SELECT a FROM table_two"),
            ),
            [
                Update(
                    to_table("table_one", quoted=False),
                    to_table("table_two", quoted=False),
                ),  # the Table node
            ],
        )

    def test_lambda(self):
        # Changing a lambda's parameter and body is reported as an Update of
        # the Lambda node.
        self._validate_delta_only(
            diff(parse_one("SELECT a, b, c, x(a -> a)"), parse_one("SELECT a, b, c, x(b -> b)")),
            [
                Update(
                    exp.Lambda(this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")]),
                    exp.Lambda(this=exp.to_identifier("b"), expressions=[exp.to_identifier("b")]),
                ),
            ],
        )

    def test_udf(self):
        # A differently named (quoted) UDF is a different node: insert + remove.
        self._validate_delta_only(
            diff(parse_one('SELECT a, b, "my.udf1"()'), parse_one('SELECT a, b, "my.udf2"()')),
            [
                Insert(parse_one('"my.udf2"()')),
                Remove(parse_one('"my.udf1"()')),
            ],
        )
        # Same UDF with one changed argument: only that argument changes.
        self._validate_delta_only(
            diff(
                parse_one('SELECT a, b, "my.udf"(x, y, z)'),
                parse_one('SELECT a, b, "my.udf"(x, y, w)'),
            ),
            [
                Insert(exp.column("w")),
                Remove(exp.column("z")),
            ],
        )

    def test_node_position_changed(self):
        # Reordering existing nodes is reported as Move edits.
        self._validate_delta_only(
            diff(parse_one("SELECT a, b, c"), parse_one("SELECT c, a, b")),
            [
                Move(parse_one("c")),  # the Column node
            ],
        )

        self._validate_delta_only(
            diff(parse_one("SELECT a + b"), parse_one("SELECT b + a")),
            [
                Move(parse_one("a")),  # the Column node
            ],
        )

        self._validate_delta_only(
            diff(parse_one("SELECT aaaa AND bbbb"), parse_one("SELECT bbbb AND aaaa")),
            [
                Move(parse_one("aaaa")),  # the Column node
            ],
        )

        self._validate_delta_only(
            diff(
                parse_one("SELECT aaaa OR bbbb OR cccc"),
                parse_one("SELECT cccc OR bbbb OR aaaa"),
            ),
            [
                Move(parse_one("aaaa")),  # the Column node
                Move(parse_one("cccc")),  # the Column node
            ],
        )

    def test_cte(self):
        # Only cte1 differs between source and target; cte2 and the outer
        # query are identical and must not appear in the delta.
        expr_src = """
            WITH
                cte1 AS (SELECT a, b, LOWER(c) AS c FROM table_one WHERE d = 'filter'),
                cte2 AS (SELECT d, e, f FROM table_two)
            SELECT a, b, d, e FROM cte1 JOIN cte2 ON f = c
        """
        expr_tgt = """
            WITH
                cte1 AS (SELECT a, b, c FROM table_one WHERE d = 'different_filter'),
                cte2 AS (SELECT d, e, f FROM table_two)
            SELECT a, b, d, e FROM cte1 JOIN cte2 ON f = c
        """

        self._validate_delta_only(
            diff(parse_one(expr_src), parse_one(expr_tgt)),
            [
                Remove(parse_one("LOWER(c) AS c")),  # the Alias node
                Remove(parse_one("LOWER(c)")),  # the Lower node
                Remove(parse_one("'filter'")),  # the Literal node
                Insert(parse_one("'different_filter'")),  # the Literal node
            ],
        )

    def test_join(self):
        # Changing the join side must replace the Join node (Remove followed
        # by Insert) rather than update it in place.
        expr_src = "SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key"
        expr_tgt = "SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key"

        changes = _delta_only(diff(parse_one(expr_src), parse_one(expr_tgt)))

        self.assertEqual(len(changes), 2)
        self.assertIsInstance(changes[0], Remove)
        self.assertIsInstance(changes[1], Insert)
        # One assertion per change so a failure names the offending type.
        for change in changes:
            self.assertIsInstance(change.expression, Join)

    def test_window_functions(self):
        expr_src = parse_one("SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b)")
        expr_tgt = parse_one("SELECT RANK() OVER (PARTITION BY a ORDER BY b)")

        # Diffing an expression against itself yields no delta.
        self._validate_delta_only(diff(expr_src, expr_src), [])

        # Swapping the window function replaces only the function node; the
        # shared OVER (...) specification is kept.
        self._validate_delta_only(
            diff(expr_src, expr_tgt),
            [
                Remove(parse_one("ROW_NUMBER()")),  # the ROW_NUMBER() function node
                Insert(parse_one("RANK()")),  # the RANK() function node
            ],
        )

    def test_pre_matchings(self):
        expr_src = parse_one("SELECT 1")
        expr_tgt = parse_one("SELECT 1, 2, 3, 4")

        # Without pre-matchings, the two root Select nodes are not paired.
        self._validate_delta_only(
            diff(expr_src, expr_tgt),
            [
                Remove(expr_src),
                Insert(expr_tgt),
                Insert(exp.Literal.number(2)),
                Insert(exp.Literal.number(3)),
                Insert(exp.Literal.number(4)),
            ],
        )

        # Pre-matching the roots leaves only the new literals in the delta.
        self._validate_delta_only(
            diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]),
            [
                Insert(exp.Literal.number(2)),
                Insert(exp.Literal.number(3)),
                Insert(exp.Literal.number(4)),
            ],
        )

        # Supplying the same pair twice is rejected.
        with self.assertRaises(ValueError):
            diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt), (expr_src, expr_tgt)])

    def test_identifier(self):
        # Adding a table-qualified column produces a single Insert.
        expr_src = parse_one("SELECT a FROM tbl")
        expr_tgt = parse_one("SELECT a, tbl.b from tbl")

        self._validate_delta_only(
            diff(expr_src, expr_tgt),
            [
                Insert(expression=exp.to_column("tbl.b")),
            ],
        )

    def _validate_delta_only(self, actual_diff, expected_delta):
        """Assert that the non-``Keep`` edits of ``actual_diff`` equal ``expected_delta``.

        The comparison is done on sets, so the order of edits and any
        duplicate edits are ignored.
        """
        actual_delta = _delta_only(actual_diff)
        self.assertEqual(set(actual_delta), set(expected_delta))


def _delta_only(changes):
    return [d for d in changes if not isinstance(d, Keep)]