diff options
Diffstat (limited to 'tests/test_expressions.py')
-rw-r--r-- | tests/test_expressions.py | 82 |
1 files changed, 62 insertions, 20 deletions
diff --git a/tests/test_expressions.py b/tests/test_expressions.py index adfd329..63371d8 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -30,7 +30,9 @@ class TestExpressions(unittest.TestCase): self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)")) self.assertEqual(exp.Table(pivots=[]), exp.Table()) self.assertNotEqual(exp.Table(pivots=[None]), exp.Table()) - self.assertEqual(exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False)) + self.assertEqual( + exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False) + ) def test_find(self): expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y") @@ -89,7 +91,9 @@ class TestExpressions(unittest.TestCase): self.assertIsNone(column.find_ancestor(exp.Join)) def test_alias_or_name(self): - expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz") + expression = parse_one( + "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" + ) self.assertEqual( [e.alias_or_name for e in expression.expressions], ["a", "B", "e", "*", "zz", "z"], @@ -166,7 +170,9 @@ class TestExpressions(unittest.TestCase): "SELECT * FROM foo WHERE ? > 100", ) self.assertEqual( - exp.replace_placeholders(parse_one("select * from :name WHERE ? > 100"), another_name="bla").sql(), + exp.replace_placeholders( + parse_one("select * from :name WHERE ? > 100"), another_name="bla" + ).sql(), "SELECT * FROM :name WHERE ? > 100", ) self.assertEqual( @@ -183,7 +189,9 @@ class TestExpressions(unittest.TestCase): ) def test_named_selects(self): - expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz") + expression = parse_one( + "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" + ) self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"]) expression = parse_one( @@ -367,7 +375,9 @@ class TestExpressions(unittest.TestCase): self.assertEqual(len(list(expression.walk())), 9) self.assertEqual(len(list(expression.walk(bfs=False))), 9) self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk())) - self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))) + self.assertTrue( + all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)) + ) def test_functions(self): self.assertIsInstance(parse_one("ABS(a)"), exp.Abs) @@ -512,14 +522,21 @@ class TestExpressions(unittest.TestCase): ), exp.Properties( expressions=[ - exp.FileFormatProperty(this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")), + exp.FileFormatProperty( + this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet") + ), exp.PartitionedByProperty( this=exp.Literal.string("PARTITIONED_BY"), - value=exp.Tuple(expressions=[exp.to_identifier("a"), exp.to_identifier("b")]), + value=exp.Tuple( + expressions=[exp.to_identifier("a"), exp.to_identifier("b")] + ), + ), + exp.AnonymousProperty( + this=exp.Literal.string("custom"), value=exp.Literal.number(1) ), - exp.AnonymousProperty(this=exp.Literal.string("custom"), value=exp.Literal.number(1)), exp.TableFormatProperty( - this=exp.Literal.string("TABLE_FORMAT"), value=exp.to_identifier("test_format") + this=exp.Literal.string("TABLE_FORMAT"), + value=exp.to_identifier("test_format"), ), exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL), exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE), @@ -538,7 +555,10 @@ class TestExpressions(unittest.TestCase): ((1, "2", None), "(1, '2', NULL)"), ([1, "2", None], "ARRAY(1, '2', NULL)"), ({"x": None}, "MAP('x', NULL)"), - (datetime.datetime(2022, 10, 1, 1, 1, 1), "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')"), + ( + datetime.datetime(2022, 10, 1, 1, 1, 1), + "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')", + ), ( datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc), "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000+0000')", @@ -548,30 +568,48 @@ class TestExpressions(unittest.TestCase): with self.subTest(value): self.assertEqual(exp.convert(value).sql(), expected) - def test_annotation_alias(self): - sql = "SELECT a, b AS B, c # comment, d AS D # another_comment FROM foo" + def test_comment_alias(self): + sql = """ + SELECT + a, + b AS B, + c, /*comment*/ + d AS D, -- another comment + CAST(x AS INT) -- final comment + FROM foo + """ expression = parse_one(sql) self.assertEqual( [e.alias_or_name for e in expression.expressions], - ["a", "B", "c", "D"], + ["a", "B", "c", "D", "x"], + ) + self.assertEqual( + expression.sql(), + "SELECT a, b AS B, c /* comment */, d AS D /* another comment */, CAST(x AS INT) /* final comment */ FROM foo", + ) + self.assertEqual( + expression.sql(comments=False), + "SELECT a, b AS B, c, d AS D, CAST(x AS INT) FROM foo", ) - self.assertEqual(expression.sql(), "SELECT a, b AS B, c, d AS D") - self.assertEqual(expression.expressions[2].name, "comment") self.assertEqual( - expression.sql(pretty=True, annotations=False), + expression.sql(pretty=True, comments=False), """SELECT a, b AS B, c, - d AS D""", + d AS D, + CAST(x AS INT) +FROM foo""", ) self.assertEqual( expression.sql(pretty=True), """SELECT a, b AS B, - c # comment, - d AS D # another_comment FROM foo""", + c, -- comment + d AS D, -- another comment + CAST(x AS INT) -- final comment +FROM foo""", ) def test_to_table(self): @@ -605,5 +643,9 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(expression, exp.Union) self.assertEqual(expression.named_selects, ["cola", "colb"]) self.assertEqual( - expression.selects, [exp.Column(this=exp.to_identifier("cola")), exp.Column(this=exp.to_identifier("colb"))] + expression.selects, + [ + exp.Column(this=exp.to_identifier("cola")), + exp.Column(this=exp.to_identifier("colb")), + ], ) |