summaryrefslogtreecommitdiffstats
path: root/tests/test_transpile.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2022-12-02 09:16:32 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2022-12-02 09:16:32 +0000
commitb3c7fe6a73484a4d2177c30f951cd11a4916ed56 (patch)
tree7192898cb782bbb0b9b13bd8d6341fe4434f0f31 /tests/test_transpile.py
parentReleasing debian version 10.0.8-1. (diff)
downloadsqlglot-b3c7fe6a73484a4d2177c30f951cd11a4916ed56.tar.xz
sqlglot-b3c7fe6a73484a4d2177c30f951cd11a4916ed56.zip
Merging upstream version 10.1.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/test_transpile.py')
-rw-r--r--tests/test_transpile.py120
1 files changed, 110 insertions, 10 deletions
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index 1bd2527..7bf53e5 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -26,6 +26,7 @@ class TestTranspile(unittest.TestCase):
)
self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date")
self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
+ self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row")
for key in ("union", "filter", "over", "from", "join"):
with self.subTest(f"alias {key}"):
@@ -38,6 +39,11 @@ class TestTranspile(unittest.TestCase):
def test_asc(self):
self.validate("SELECT x FROM y ORDER BY x ASC", "SELECT x FROM y ORDER BY x")
+ def test_unary(self):
+ self.validate("+++1", "1")
+ self.validate("+-1", "-1")
+ self.validate("+- - -1", "- - -1")
+
def test_paren(self):
with self.assertRaises(ParseError):
transpile("1 + (2 + 3")
@@ -58,7 +64,7 @@ class TestTranspile(unittest.TestCase):
)
self.validate(
"SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ",
- "SELECT\n FOO -- x\n , BAR -- y\n , BAZ",
+ "SELECT\n FOO /* x */\n , BAR /* y */\n , BAZ",
leading_comma=True,
pretty=True,
)
@@ -78,7 +84,8 @@ class TestTranspile(unittest.TestCase):
def test_comments(self):
self.validate("SELECT */*comment*/", "SELECT * /* comment */")
self.validate(
- "SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */"
+ "SELECT * FROM table /*comment 1*/ /*comment 2*/",
+ "SELECT * FROM table /* comment 1 */ /* comment 2 */",
)
self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
@@ -112,6 +119,53 @@ class TestTranspile(unittest.TestCase):
)
self.validate(
"""
+-- comment 1
+-- comment 2
+-- comment 3
+SELECT * FROM foo
+ """,
+ "/* comment 1 */ /* comment 2 */ /* comment 3 */ SELECT * FROM foo",
+ )
+ self.validate(
+ """
+-- comment 1
+-- comment 2
+-- comment 3
+SELECT * FROM foo""",
+ """/* comment 1 */
+/* comment 2 */
+/* comment 3 */
+SELECT
+ *
+FROM foo""",
+ pretty=True,
+ )
+ self.validate(
+ """
+SELECT * FROM tbl /*line1
+line2
+line3*/ /*another comment*/ where 1=1 -- comment at the end""",
+ """SELECT * FROM tbl /* line1
+line2
+line3 */ /* another comment */ WHERE 1 = 1 /* comment at the end */""",
+ )
+ self.validate(
+ """
+SELECT * FROM tbl /*line1
+line2
+line3*/ /*another comment*/ where 1=1 -- comment at the end""",
+ """SELECT
+ *
+FROM tbl /* line1
+line2
+line3 */
+/* another comment */
+WHERE
+ 1 = 1 /* comment at the end */""",
+ pretty=True,
+ )
+ self.validate(
+ """
/* multi
line
comment
@@ -130,8 +184,8 @@ class TestTranspile(unittest.TestCase):
*/
SELECT
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
- CAST(x AS INT), -- comment 3
- y -- comment 4
+ CAST(x AS INT), /* comment 3 */
+ y /* comment 4 */
FROM bar /* comment 5 */, tbl /* comment 6 */""",
read="mysql",
pretty=True,
@@ -364,33 +418,79 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
@mock.patch("sqlglot.parser.logger")
def test_error_level(self, logger):
invalid = "x + 1. ("
- errors = [
+ expected_messages = [
"Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 8.\n x + 1. \033[4m(\033[0m",
"Expecting ). Line 1, Col: 8.\n x + 1. \033[4m(\033[0m",
]
+ expected_errors = [
+ {
+ "description": "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>",
+ "line": 1,
+ "col": 8,
+ "start_context": "x + 1. ",
+ "highlight": "(",
+ "end_context": "",
+ "into_expression": None,
+ },
+ {
+ "description": "Expecting )",
+ "line": 1,
+ "col": 8,
+ "start_context": "x + 1. ",
+ "highlight": "(",
+ "end_context": "",
+ "into_expression": None,
+ },
+ ]
transpile(invalid, error_level=ErrorLevel.WARN)
- for error in errors:
+ for error in expected_messages:
assert_logger_contains(error, logger)
with self.assertRaises(ParseError) as ctx:
transpile(invalid, error_level=ErrorLevel.IMMEDIATE)
- self.assertEqual(str(ctx.exception), errors[0])
+ self.assertEqual(str(ctx.exception), expected_messages[0])
+ self.assertEqual(ctx.exception.errors[0], expected_errors[0])
with self.assertRaises(ParseError) as ctx:
transpile(invalid, error_level=ErrorLevel.RAISE)
- self.assertEqual(str(ctx.exception), "\n\n".join(errors))
+ self.assertEqual(str(ctx.exception), "\n\n".join(expected_messages))
+ self.assertEqual(ctx.exception.errors, expected_errors)
more_than_max_errors = "(((("
- expected = (
+ expected_messages = (
"Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
"Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>. Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
"Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n"
"... and 2 more"
)
+ expected_errors = [
+ {
+ "description": "Expecting )",
+ "line": 1,
+ "col": 4,
+ "start_context": "(((",
+ "highlight": "(",
+ "end_context": "",
+ "into_expression": None,
+ },
+ {
+ "description": "Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>",
+ "line": 1,
+ "col": 4,
+ "start_context": "(((",
+ "highlight": "(",
+ "end_context": "",
+ "into_expression": None,
+ },
+ ]
+ # Also expect three trailing structured errors that match the first
+ expected_errors += [expected_errors[0]] * 3
+
with self.assertRaises(ParseError) as ctx:
transpile(more_than_max_errors, error_level=ErrorLevel.RAISE)
- self.assertEqual(str(ctx.exception), expected)
+ self.assertEqual(str(ctx.exception), expected_messages)
+ self.assertEqual(ctx.exception.errors, expected_errors)
@mock.patch("sqlglot.generator.logger")
def test_unsupported_level(self, logger):