summaryrefslogtreecommitdiffstats
path: root/tests/test_transpile.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_transpile.py')
-rw-r--r--tests/test_transpile.py51
1 files changed, 27 insertions, 24 deletions
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index 28bcc7a..4bec2ac 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -42,6 +42,20 @@ class TestTranspile(unittest.TestCase):
"SELECT * FROM x WHERE a = ANY (SELECT 1)",
)
+ def test_leading_comma(self):
+ self.validate(
+ "SELECT FOO, BAR, BAZ",
+ "SELECT\n FOO\n , BAR\n , BAZ",
+ leading_comma=True,
+ pretty=True,
+ )
+ # without pretty, this should be a no-op
+ self.validate(
+ "SELECT FOO, BAR, BAZ",
+ "SELECT FOO, BAR, BAZ",
+ leading_comma=True,
+ )
+
def test_space(self):
self.validate("SELECT MIN(3)>MIN(2)", "SELECT MIN(3) > MIN(2)")
self.validate("SELECT MIN(3)>=MIN(2)", "SELECT MIN(3) >= MIN(2)")
@@ -108,6 +122,11 @@ class TestTranspile(unittest.TestCase):
"extract(month from '2021-01-31'::timestamp without time zone)",
"EXTRACT(month FROM CAST('2021-01-31' AS TIMESTAMP))",
)
+ self.validate("extract(week from current_date + 2)", "EXTRACT(week FROM CURRENT_DATE + 2)")
+ self.validate(
+ "EXTRACT(minute FROM datetime1 - datetime2)",
+ "EXTRACT(minute FROM datetime1 - datetime2)",
+ )
def test_if(self):
self.validate(
@@ -122,18 +141,14 @@ class TestTranspile(unittest.TestCase):
"SELECT IF a > 1 THEN b ELSE c END",
"SELECT CASE WHEN a > 1 THEN b ELSE c END",
)
- self.validate(
- "SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo"
- )
+ self.validate("SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo")
def test_ignore_nulls(self):
self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)")
def test_time(self):
self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)")
- self.validate(
- "TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)"
- )
+ self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)")
self.validate(
"TIMESTAMP(9) WITH TIME ZONE '2020-01-01'",
"CAST('2020-01-01' AS TIMESTAMPTZ(9))",
@@ -159,9 +174,7 @@ class TestTranspile(unittest.TestCase):
self.validate("DATE '2020-01-01'", "CAST('2020-01-01' AS DATE)")
self.validate("'2020-01-01'::DATE", "CAST('2020-01-01' AS DATE)")
self.validate("STR_TO_TIME('x', 'y')", "STRPTIME('x', 'y')", write="duckdb")
- self.validate(
- "STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb"
- )
+ self.validate("STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb")
self.validate("TIME_TO_STR(x, 'y')", "STRFTIME(x, 'y')", write="duckdb")
self.validate("TIME_TO_UNIX(x)", "EPOCH(x)", write="duckdb")
self.validate(
@@ -209,12 +222,8 @@ class TestTranspile(unittest.TestCase):
self.validate("TIME_STR_TO_DATE(x)", "TIME_STR_TO_DATE(x)", write=None)
self.validate("TIME_STR_TO_DATE(x)", "TO_DATE(x)", write="hive")
- self.validate(
- "UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive"
- )
- self.validate(
- "STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive"
- )
+ self.validate("UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive")
+ self.validate("STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive")
self.validate("IF(x > 1, x + 1)", "IF(x > 1, x + 1)", write="presto")
self.validate("IF(x > 1, 1 + 1)", "IF(x > 1, 1 + 1)", write="hive")
self.validate("IF(x > 1, 1, 0)", "IF(x > 1, 1, 0)", write="hive")
@@ -232,9 +241,7 @@ class TestTranspile(unittest.TestCase):
)
self.validate("STR_TO_TIME('x', 'y')", "DATE_PARSE('x', 'y')", write="presto")
- self.validate(
- "STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto"
- )
+ self.validate("STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto")
self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="presto")
self.validate("TIME_TO_UNIX(x)", "TO_UNIXTIME(x)", write="presto")
self.validate(
@@ -245,9 +252,7 @@ class TestTranspile(unittest.TestCase):
self.validate("UNIX_TO_TIME(123)", "FROM_UNIXTIME(123)", write="presto")
self.validate("STR_TO_TIME('x', 'y')", "TO_TIMESTAMP('x', 'y')", write="spark")
- self.validate(
- "STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark"
- )
+ self.validate("STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark")
self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="spark")
self.validate(
@@ -283,9 +288,7 @@ class TestTranspile(unittest.TestCase):
def test_partial(self):
for sql in load_sql_fixtures("partial.sql"):
with self.subTest(sql):
- self.assertEqual(
- transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip()
- )
+ self.assertEqual(transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip())
def test_pretty(self):
for _, sql, pretty in load_sql_fixture_pairs("pretty.sql"):