summaryrefslogtreecommitdiffstats
path: root/tests/test_expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_expressions.py')
-rw-r--r--tests/test_expressions.py47
1 files changed, 39 insertions, 8 deletions
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index 11f8fd3..ed19ac1 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -22,6 +22,9 @@ class TestExpressions(unittest.TestCase):
pass
def test_eq(self):
+ query = parse_one("SELECT x FROM t")
+ self.assertEqual(query, query.copy())
+
self.assertNotEqual(exp.to_identifier("a"), exp.to_identifier("A"))
self.assertEqual(
@@ -498,6 +501,18 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(expression.transform(fun).sql(), "FUN(a)")
+ def test_transform_with_parent_mutation(self):
+ expression = parse_one("SELECT COUNT(1) FROM table")
+
+ def fun(node):
+ if str(node) == "COUNT(1)":
+ # node gets silently mutated here - its parent points to the filter node
+ return exp.Filter(this=node, expression=exp.Where(this=exp.true()))
+ return node
+
+ transformed = expression.transform(fun)
+ self.assertEqual(transformed.sql(), "SELECT COUNT(1) FILTER(WHERE TRUE) FROM table")
+
def test_transform_multiple_children(self):
expression = parse_one("SELECT * FROM x")
@@ -517,7 +532,6 @@ class TestExpressions(unittest.TestCase):
return node
self.assertEqual(expression.transform(remove_column_b).sql(), "SELECT a FROM x")
- self.assertEqual(expression.transform(lambda _: None), None)
expression = parse_one("CAST(x AS FLOAT)")
@@ -544,6 +558,11 @@ class TestExpressions(unittest.TestCase):
expression.find(exp.Table).replace(parse_one("y"))
self.assertEqual(expression.sql(), "SELECT c, b FROM y")
+ # we try to replace a with a list but a's parent is actually ordered, not the ORDER BY node
+ expression = parse_one("SELECT * FROM x ORDER BY a DESC, c")
+ expression.find(exp.Ordered).this.replace([exp.column("a").asc(), exp.column("b").desc()])
+ self.assertEqual(expression.sql(), "SELECT * FROM x ORDER BY a, b DESC, c")
+
def test_arg_deletion(self):
# Using the pop helper method
expression = parse_one("SELECT a, b FROM x")
@@ -573,10 +592,8 @@ class TestExpressions(unittest.TestCase):
expression = parse_one("SELECT * FROM (SELECT * FROM x)")
self.assertEqual(len(list(expression.walk())), 9)
self.assertEqual(len(list(expression.walk(bfs=False))), 9)
- self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk()))
- self.assertTrue(
- all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))
- )
+ self.assertTrue(all(isinstance(e, exp.Expression) for e in expression.walk()))
+ self.assertTrue(all(isinstance(e, exp.Expression) for e in expression.walk(bfs=False)))
def test_functions(self):
self.assertIsInstance(parse_one("x LIKE ANY (y)"), exp.Like)
@@ -611,7 +628,9 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(parse_one("LEAST(a, b)"), exp.Least)
self.assertIsInstance(parse_one("LIKE(x, y)"), exp.Like)
self.assertIsInstance(parse_one("LN(a)"), exp.Ln)
- self.assertIsInstance(parse_one("LOG10(a)"), exp.Log10)
+ self.assertIsInstance(parse_one("LOG(b, n)"), exp.Log)
+ self.assertIsInstance(parse_one("LOG2(a)"), exp.Log)
+ self.assertIsInstance(parse_one("LOG10(a)"), exp.Log)
self.assertIsInstance(parse_one("MAX(a)"), exp.Max)
self.assertIsInstance(parse_one("MIN(a)"), exp.Min)
self.assertIsInstance(parse_one("MONTH(a)"), exp.Month)
@@ -765,6 +784,15 @@ class TestExpressions(unittest.TestCase):
self.assertRaises(ValueError, exp.Properties.from_dict, {"FORMAT": object})
def test_convert(self):
+ from collections import namedtuple
+
+ PointTuple = namedtuple("Point", ["x", "y"])
+
+ class PointClass:
+ def __init__(self, x=0, y=0):
+ self.x = x
+ self.y = y
+
for value, expected in [
(1, "1"),
("1", "'1'"),
@@ -775,14 +803,17 @@ class TestExpressions(unittest.TestCase):
({"x": None}, "MAP(ARRAY('x'), ARRAY(NULL))"),
(
datetime.datetime(2022, 10, 1, 1, 1, 1, 1),
- "TIME_STR_TO_TIME('2022-10-01T01:01:01.000001+00:00')",
+ "TIME_STR_TO_TIME('2022-10-01 01:01:01.000001+00:00')",
),
(
datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc),
- "TIME_STR_TO_TIME('2022-10-01T01:01:01+00:00')",
+ "TIME_STR_TO_TIME('2022-10-01 01:01:01+00:00')",
),
(datetime.date(2022, 10, 1), "DATE_STR_TO_DATE('2022-10-01')"),
(math.nan, "NULL"),
+ (b"\x00\x00\x00\x00\x00\x00\x07\xd3", "2003"),
+ (PointTuple(1, 2), "STRUCT(1 AS x, 2 AS y)"),
+ (PointClass(1, 2), "STRUCT(1 AS x, 2 AS y)"),
]:
with self.subTest(value):
self.assertEqual(exp.convert(value).sql(), expected)