diff options
Diffstat (limited to 'tests/test_transpile.py')
-rw-r--r-- | tests/test_transpile.py | 94 |
1 files changed, 84 insertions, 10 deletions
diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 36e0aa6..d68f6f8 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -20,6 +20,9 @@ class TestTranspile(unittest.TestCase): self.assertEqual(transpile(sql, **kwargs)[0], target) def test_alias(self): + self.assertEqual(transpile("SELECT SUM(y) KEEP")[0], "SELECT SUM(y) AS KEEP") + self.assertEqual(transpile("SELECT 1 overwrite")[0], "SELECT 1 AS overwrite") + self.assertEqual(transpile("SELECT 1 is")[0], "SELECT 1 AS is") self.assertEqual(transpile("SELECT 1 current_time")[0], "SELECT 1 AS current_time") self.assertEqual( transpile("SELECT 1 current_timestamp")[0], "SELECT 1 AS current_timestamp" @@ -87,6 +90,7 @@ class TestTranspile(unittest.TestCase): self.validate("SELECT 3>=3", "SELECT 3 >= 3") def test_comments(self): + self.validate("SELECT 1 /*/2 */", "SELECT 1 /* /2 */") self.validate("SELECT */*comment*/", "SELECT * /* comment */") self.validate( "SELECT * FROM table /*comment 1*/ /*comment 2*/", @@ -200,6 +204,65 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", read="mysql", pretty=True, ) + self.validate( + """ + SELECT a FROM b + WHERE foo + -- comment 1 + AND bar + -- comment 2 + AND bla; + """, + "SELECT a FROM b WHERE foo AND /* comment 1 */ bar AND /* comment 2 */ bla", + ) + self.validate( + """ + SELECT a FROM b WHERE foo + -- comment 1 + """, + "SELECT a FROM b WHERE foo /* comment 1 */", + ) + self.validate( + """ + select a from b + where foo + -- comment 1 + and bar + -- comment 2 + and bla + """, + """SELECT + a +FROM b +WHERE + foo /* comment 1 */ AND bar AND bla /* comment 2 */""", + pretty=True, + ) + self.validate( + """ + -- test + WITH v AS ( + SELECT + 1 AS literal + ) + SELECT + * + FROM v + """, + """/* test */ +WITH v AS ( + SELECT + 1 AS literal +) +SELECT + * +FROM v""", + pretty=True, + ) + self.validate( + "(/* 1 */ 1 ) /* 2 */", + "(1) /* 1 */ /* 2 */", + ) def test_types(self): self.validate("INT 1", "CAST(1 AS INT)") @@ -288,7 +351,6 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", "ALTER TABLE integers ADD k INTEGER", "ALTER TABLE integers ADD COLUMN k INT", ) - self.validate("ALTER TABLE integers DROP k", "ALTER TABLE integers DROP COLUMN k") self.validate( "ALTER TABLE integers ALTER i SET DATA TYPE VARCHAR", "ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR", @@ -299,6 +361,11 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", ) def test_time(self): + self.validate("INTERVAL '1 day'", "INTERVAL '1' day") + self.validate("INTERVAL '1 days' * 5", "INTERVAL '1' days * 5") + self.validate("5 * INTERVAL '1 day'", "5 * INTERVAL '1' day") + self.validate("INTERVAL 1 day", "INTERVAL '1' day") + self.validate("INTERVAL 2 months", "INTERVAL '2' months") self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)") self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)") self.validate( @@ -431,6 +498,13 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", mock_logger.warning.assert_any_call("Applying array index offset (%s)", 1) mock_logger.warning.assert_any_call("Applying array index offset (%s)", -1) + self.validate("x[x - 1]", "x[x - 1]", write="presto", identity=False) + self.validate( + "x[array_size(y) - 1]", "x[CARDINALITY(y) - 1 + 1]", write="presto", identity=False + ) + self.validate("x[3 - 1]", "x[3]", write="presto", identity=False) + self.validate("MAP(a, b)[0]", "MAP(a, b)[0]", write="presto", identity=False) + def test_identify_lambda(self): self.validate("x(y -> y)", 'X("y" -> "y")', identify=True) @@ -467,14 +541,14 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", def test_error_level(self, logger): invalid = "x + 1. (" expected_messages = [ - "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 8.\n x + 1. \033[4m(\033[0m", - "Expecting ). Line 1, Col: 8.\n x + 1. \033[4m(\033[0m", + "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 9.\n x + 1. \033[4m(\033[0m", + "Expecting ). Line 1, Col: 9.\n x + 1. \033[4m(\033[0m", ] expected_errors = [ { "description": "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>", "line": 1, - "col": 8, + "col": 9, "start_context": "x + 1. ", "highlight": "(", "end_context": "", @@ -483,7 +557,7 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", { "description": "Expecting )", "line": 1, - "col": 8, + "col": 9, "start_context": "x + 1. ", "highlight": "(", "end_context": "", @@ -507,16 +581,16 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", more_than_max_errors = "((((" expected_messages = ( - "Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>. Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" - "Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" - "Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" + "Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>. Line 1, Col: 5.\n (((\033[4m(\033[0m\n\n" + "Expecting ). Line 1, Col: 5.\n (((\033[4m(\033[0m\n\n" + "Expecting ). Line 1, Col: 5.\n (((\033[4m(\033[0m\n\n" "... and 2 more" ) expected_errors = [ { "description": "Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>", "line": 1, - "col": 4, + "col": 5, "start_context": "(((", "highlight": "(", "end_context": "", @@ -525,7 +599,7 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", { "description": "Expecting )", "line": 1, - "col": 4, + "col": 5, "start_context": "(((", "highlight": "(", "end_context": "", |