diff options
Diffstat (limited to 'tests/test_expressions.py')
-rw-r--r-- | tests/test_expressions.py | 47 |
1 files changed, 39 insertions, 8 deletions
diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 11f8fd3..ed19ac1 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -22,6 +22,9 @@ class TestExpressions(unittest.TestCase): pass def test_eq(self): + query = parse_one("SELECT x FROM t") + self.assertEqual(query, query.copy()) + self.assertNotEqual(exp.to_identifier("a"), exp.to_identifier("A")) self.assertEqual( @@ -498,6 +501,18 @@ class TestExpressions(unittest.TestCase): self.assertEqual(expression.transform(fun).sql(), "FUN(a)") + def test_transform_with_parent_mutation(self): + expression = parse_one("SELECT COUNT(1) FROM table") + + def fun(node): + if str(node) == "COUNT(1)": + # node gets silently mutated here - its parent points to the filter node + return exp.Filter(this=node, expression=exp.Where(this=exp.true())) + return node + + transformed = expression.transform(fun) + self.assertEqual(transformed.sql(), "SELECT COUNT(1) FILTER(WHERE TRUE) FROM table") + def test_transform_multiple_children(self): expression = parse_one("SELECT * FROM x") @@ -517,7 +532,6 @@ class TestExpressions(unittest.TestCase): return node self.assertEqual(expression.transform(remove_column_b).sql(), "SELECT a FROM x") - self.assertEqual(expression.transform(lambda _: None), None) expression = parse_one("CAST(x AS FLOAT)") @@ -544,6 +558,11 @@ class TestExpressions(unittest.TestCase): expression.find(exp.Table).replace(parse_one("y")) self.assertEqual(expression.sql(), "SELECT c, b FROM y") + # we try to replace a with a list but a's parent is actually ordered, not the ORDER BY node + expression = parse_one("SELECT * FROM x ORDER BY a DESC, c") + expression.find(exp.Ordered).this.replace([exp.column("a").asc(), exp.column("b").desc()]) + self.assertEqual(expression.sql(), "SELECT * FROM x ORDER BY a, b DESC, c") + def test_arg_deletion(self): # Using the pop helper method expression = parse_one("SELECT a, b FROM x") @@ -573,10 +592,8 @@ class TestExpressions(unittest.TestCase): expression = parse_one("SELECT * FROM (SELECT * FROM x)") self.assertEqual(len(list(expression.walk())), 9) self.assertEqual(len(list(expression.walk(bfs=False))), 9) - self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk())) - self.assertTrue( - all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)) - ) + self.assertTrue(all(isinstance(e, exp.Expression) for e in expression.walk())) + self.assertTrue(all(isinstance(e, exp.Expression) for e in expression.walk(bfs=False))) def test_functions(self): self.assertIsInstance(parse_one("x LIKE ANY (y)"), exp.Like) @@ -611,7 +628,9 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("LEAST(a, b)"), exp.Least) self.assertIsInstance(parse_one("LIKE(x, y)"), exp.Like) self.assertIsInstance(parse_one("LN(a)"), exp.Ln) - self.assertIsInstance(parse_one("LOG10(a)"), exp.Log10) + self.assertIsInstance(parse_one("LOG(b, n)"), exp.Log) + self.assertIsInstance(parse_one("LOG2(a)"), exp.Log) + self.assertIsInstance(parse_one("LOG10(a)"), exp.Log) self.assertIsInstance(parse_one("MAX(a)"), exp.Max) self.assertIsInstance(parse_one("MIN(a)"), exp.Min) self.assertIsInstance(parse_one("MONTH(a)"), exp.Month) @@ -765,6 +784,15 @@ class TestExpressions(unittest.TestCase): self.assertRaises(ValueError, exp.Properties.from_dict, {"FORMAT": object}) def test_convert(self): + from collections import namedtuple + + PointTuple = namedtuple("Point", ["x", "y"]) + + class PointClass: + def __init__(self, x=0, y=0): + self.x = x + self.y = y + for value, expected in [ (1, "1"), ("1", "'1'"), @@ -775,14 +803,17 @@ class TestExpressions(unittest.TestCase): ({"x": None}, "MAP(ARRAY('x'), ARRAY(NULL))"), ( datetime.datetime(2022, 10, 1, 1, 1, 1, 1), - "TIME_STR_TO_TIME('2022-10-01T01:01:01.000001+00:00')", + "TIME_STR_TO_TIME('2022-10-01 01:01:01.000001+00:00')", ), ( datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc), - "TIME_STR_TO_TIME('2022-10-01T01:01:01+00:00')", + "TIME_STR_TO_TIME('2022-10-01 01:01:01+00:00')", ), (datetime.date(2022, 10, 1), "DATE_STR_TO_DATE('2022-10-01')"), (math.nan, "NULL"), + (b"\x00\x00\x00\x00\x00\x00\x07\xd3", "2003"), + (PointTuple(1, 2), "STRUCT(1 AS x, 2 AS y)"), + (PointClass(1, 2), "STRUCT(1 AS x, 2 AS y)"), ]: with self.subTest(value): self.assertEqual(exp.convert(value).sql(), expected) |