diff options
Diffstat (limited to 'tests/test_transpile.py')
-rw-r--r-- | tests/test_transpile.py | 51 |
1 files changed, 27 insertions, 24 deletions
diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 28bcc7a..4bec2ac 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -42,6 +42,20 @@ class TestTranspile(unittest.TestCase): "SELECT * FROM x WHERE a = ANY (SELECT 1)", ) + def test_leading_comma(self): + self.validate( + "SELECT FOO, BAR, BAZ", + "SELECT\n FOO\n , BAR\n , BAZ", + leading_comma=True, + pretty=True, + ) + # without pretty, this should be a no-op + self.validate( + "SELECT FOO, BAR, BAZ", + "SELECT FOO, BAR, BAZ", + leading_comma=True, + ) + def test_space(self): self.validate("SELECT MIN(3)>MIN(2)", "SELECT MIN(3) > MIN(2)") self.validate("SELECT MIN(3)>=MIN(2)", "SELECT MIN(3) >= MIN(2)") @@ -108,6 +122,11 @@ class TestTranspile(unittest.TestCase): "extract(month from '2021-01-31'::timestamp without time zone)", "EXTRACT(month FROM CAST('2021-01-31' AS TIMESTAMP))", ) + self.validate("extract(week from current_date + 2)", "EXTRACT(week FROM CURRENT_DATE + 2)") + self.validate( + "EXTRACT(minute FROM datetime1 - datetime2)", + "EXTRACT(minute FROM datetime1 - datetime2)", + ) def test_if(self): self.validate( @@ -122,18 +141,14 @@ class TestTranspile(unittest.TestCase): "SELECT IF a > 1 THEN b ELSE c END", "SELECT CASE WHEN a > 1 THEN b ELSE c END", ) - self.validate( - "SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo" - ) + self.validate("SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo") def test_ignore_nulls(self): self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)") def test_time(self): self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)") - self.validate( - "TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)" - ) + self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)") self.validate( "TIMESTAMP(9) WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ(9))", @@ -159,9 +174,7 @@ class TestTranspile(unittest.TestCase): self.validate("DATE '2020-01-01'", "CAST('2020-01-01' AS DATE)") self.validate("'2020-01-01'::DATE", "CAST('2020-01-01' AS DATE)") self.validate("STR_TO_TIME('x', 'y')", "STRPTIME('x', 'y')", write="duckdb") - self.validate( - "STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb" - ) + self.validate("STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb") self.validate("TIME_TO_STR(x, 'y')", "STRFTIME(x, 'y')", write="duckdb") self.validate("TIME_TO_UNIX(x)", "EPOCH(x)", write="duckdb") self.validate( @@ -209,12 +222,8 @@ class TestTranspile(unittest.TestCase): self.validate("TIME_STR_TO_DATE(x)", "TIME_STR_TO_DATE(x)", write=None) self.validate("TIME_STR_TO_DATE(x)", "TO_DATE(x)", write="hive") - self.validate( - "UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive" - ) - self.validate( - "STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive" - ) + self.validate("UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive") + self.validate("STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive") self.validate("IF(x > 1, x + 1)", "IF(x > 1, x + 1)", write="presto") self.validate("IF(x > 1, 1 + 1)", "IF(x > 1, 1 + 1)", write="hive") self.validate("IF(x > 1, 1, 0)", "IF(x > 1, 1, 0)", write="hive") @@ -232,9 +241,7 @@ class TestTranspile(unittest.TestCase): ) self.validate("STR_TO_TIME('x', 'y')", "DATE_PARSE('x', 'y')", write="presto") - self.validate( - "STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto" - ) + self.validate("STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto") self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="presto") self.validate("TIME_TO_UNIX(x)", "TO_UNIXTIME(x)", write="presto") self.validate( @@ -245,9 +252,7 @@ class TestTranspile(unittest.TestCase): self.validate("UNIX_TO_TIME(123)", "FROM_UNIXTIME(123)", write="presto") self.validate("STR_TO_TIME('x', 'y')", "TO_TIMESTAMP('x', 'y')", write="spark") - self.validate( - "STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark" - ) + self.validate("STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark") self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="spark") self.validate( @@ -283,9 +288,7 @@ class TestTranspile(unittest.TestCase): def test_partial(self): for sql in load_sql_fixtures("partial.sql"): with self.subTest(sql): - self.assertEqual( - transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip() - ) + self.assertEqual(transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip()) def test_pretty(self): for _, sql, pretty in load_sql_fixture_pairs("pretty.sql"): |