From b3c7fe6a73484a4d2177c30f951cd11a4916ed56 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 2 Dec 2022 10:16:32 +0100 Subject: Merging upstream version 10.1.3. Signed-off-by: Daniel Baumann --- tests/test_transpile.py | 120 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 110 insertions(+), 10 deletions(-) (limited to 'tests/test_transpile.py') diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 1bd2527..7bf53e5 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -26,6 +26,7 @@ class TestTranspile(unittest.TestCase): ) self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date") self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime") + self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row") for key in ("union", "filter", "over", "from", "join"): with self.subTest(f"alias {key}"): @@ -38,6 +39,11 @@ class TestTranspile(unittest.TestCase): def test_asc(self): self.validate("SELECT x FROM y ORDER BY x ASC", "SELECT x FROM y ORDER BY x") + def test_unary(self): + self.validate("+++1", "1") + self.validate("+-1", "-1") + self.validate("+- - -1", "- - -1") + def test_paren(self): with self.assertRaises(ParseError): transpile("1 + (2 + 3") @@ -58,7 +64,7 @@ class TestTranspile(unittest.TestCase): ) self.validate( "SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ", - "SELECT\n FOO -- x\n , BAR -- y\n , BAZ", + "SELECT\n FOO /* x */\n , BAR /* y */\n , BAZ", leading_comma=True, pretty=True, ) @@ -78,7 +84,8 @@ class TestTranspile(unittest.TestCase): def test_comments(self): self.validate("SELECT */*comment*/", "SELECT * /* comment */") self.validate( - "SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */" + "SELECT * FROM table /*comment 1*/ /*comment 2*/", + "SELECT * FROM table /* comment 1 */ /* comment 2 */", ) self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */") self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo") @@ -112,6 +119,53 @@ class TestTranspile(unittest.TestCase): ) self.validate( """ +-- comment 1 +-- comment 2 +-- comment 3 +SELECT * FROM foo + """, + "/* comment 1 */ /* comment 2 */ /* comment 3 */ SELECT * FROM foo", + ) + self.validate( + """ +-- comment 1 +-- comment 2 +-- comment 3 +SELECT * FROM foo""", + """/* comment 1 */ +/* comment 2 */ +/* comment 3 */ +SELECT + * +FROM foo""", + pretty=True, + ) + self.validate( + """ +SELECT * FROM tbl /*line1 +line2 +line3*/ /*another comment*/ where 1=1 -- comment at the end""", + """SELECT * FROM tbl /* line1 +line2 +line3 */ /* another comment */ WHERE 1 = 1 /* comment at the end */""", + ) + self.validate( + """ +SELECT * FROM tbl /*line1 +line2 +line3*/ /*another comment*/ where 1=1 -- comment at the end""", + """SELECT + * +FROM tbl /* line1 +line2 +line3 */ +/* another comment */ +WHERE + 1 = 1 /* comment at the end */""", + pretty=True, + ) + self.validate( + """ /* multi line comment @@ -130,8 +184,8 @@ class TestTranspile(unittest.TestCase): */ SELECT tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */, - CAST(x AS INT), -- comment 3 - y -- comment 4 + CAST(x AS INT), /* comment 3 */ + y /* comment 4 */ FROM bar /* comment 5 */, tbl /* comment 6 */""", read="mysql", pretty=True, @@ -364,33 +418,79 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", @mock.patch("sqlglot.parser.logger") def test_error_level(self, logger): invalid = "x + 1. (" - errors = [ + expected_messages = [ "Required keyword: 'expressions' missing for . Line 1, Col: 8.\n x + 1. \033[4m(\033[0m", "Expecting ). Line 1, Col: 8.\n x + 1. \033[4m(\033[0m", ] + expected_errors = [ + { + "description": "Required keyword: 'expressions' missing for ", + "line": 1, + "col": 8, + "start_context": "x + 1. ", + "highlight": "(", + "end_context": "", + "into_expression": None, + }, + { + "description": "Expecting )", + "line": 1, + "col": 8, + "start_context": "x + 1. ", + "highlight": "(", + "end_context": "", + "into_expression": None, + }, + ] transpile(invalid, error_level=ErrorLevel.WARN) - for error in errors: + for error in expected_messages: assert_logger_contains(error, logger) with self.assertRaises(ParseError) as ctx: transpile(invalid, error_level=ErrorLevel.IMMEDIATE) - self.assertEqual(str(ctx.exception), errors[0]) + self.assertEqual(str(ctx.exception), expected_messages[0]) + self.assertEqual(ctx.exception.errors[0], expected_errors[0]) with self.assertRaises(ParseError) as ctx: transpile(invalid, error_level=ErrorLevel.RAISE) - self.assertEqual(str(ctx.exception), "\n\n".join(errors)) + self.assertEqual(str(ctx.exception), "\n\n".join(expected_messages)) + self.assertEqual(ctx.exception.errors, expected_errors) more_than_max_errors = "((((" - expected = ( + expected_messages = ( "Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" "Required keyword: 'this' missing for . Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" "Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" "... and 2 more" ) + expected_errors = [ + { + "description": "Expecting )", + "line": 1, + "col": 4, + "start_context": "(((", + "highlight": "(", + "end_context": "", + "into_expression": None, + }, + { + "description": "Required keyword: 'this' missing for ", + "line": 1, + "col": 4, + "start_context": "(((", + "highlight": "(", + "end_context": "", + "into_expression": None, + }, + ] + # Also expect three trailing structured errors that match the first + expected_errors += [expected_errors[0]] * 3 + with self.assertRaises(ParseError) as ctx: transpile(more_than_max_errors, error_level=ErrorLevel.RAISE) - self.assertEqual(str(ctx.exception), expected) + self.assertEqual(str(ctx.exception), expected_messages) + self.assertEqual(ctx.exception.errors, expected_errors) @mock.patch("sqlglot.generator.logger") def test_unsupported_level(self, logger): -- cgit v1.2.3