diff options
author | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:45:55 +0000 |
---|---|---|
committer | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:45:55 +0000 |
commit | 02df6cdb000c8dbf739abda2af321a4f90d1b059 (patch) | |
tree | 2fc1daf848082ff67a11e60025cac260e3c318b2 /tests/test_transpile.py | |
parent | Adding upstream version 19.0.1. (diff) | |
download | sqlglot-02df6cdb000c8dbf739abda2af321a4f90d1b059.tar.xz sqlglot-02df6cdb000c8dbf739abda2af321a4f90d1b059.zip |
Adding upstream version 20.1.0.upstream/20.1.0
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'tests/test_transpile.py')
-rw-r--r-- | tests/test_transpile.py | 75 |
1 files changed, 63 insertions, 12 deletions
diff --git a/tests/test_transpile.py b/tests/test_transpile.py index c16b1f6..b732b45 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -4,6 +4,7 @@ from unittest import mock from sqlglot import parse_one, transpile from sqlglot.errors import ErrorLevel, ParseError, UnsupportedError +from sqlglot.helper import logger as helper_logger from tests.helpers import ( assert_logger_contains, load_sql_fixture_pairs, @@ -91,6 +92,10 @@ class TestTranspile(unittest.TestCase): def test_comments(self): self.validate( + "SELECT c AS /* foo */ (a, b, c) FROM t", + "SELECT c AS (a, b, c) /* foo */ FROM t", + ) + self.validate( "SELECT * FROM t1\n/*x*/\nUNION ALL SELECT * FROM t2", "SELECT * FROM t1 /* x */ UNION ALL SELECT * FROM t2", ) @@ -434,6 +439,40 @@ SELECT FROM dw_1_dw_1_1.exactonline_2.transactionlines""", pretty=True, ) + self.validate( + """/* The result of some calculations + */ +with + base as ( + select + sum(sb.hep_amount) as hep_amount, + -- I AM REMOVED + sum(sb.hep_budget) + /* Budget defined in sharepoint */ + as blub + , 1 as bla + from gold.data_budget sb + group by all + ) +select + * +from base +""", + """/* The result of some calculations + */ +WITH base AS ( + SELECT + SUM(sb.hep_amount) AS hep_amount, + SUM(sb.hep_budget) /* I AM REMOVED */ AS blub, /* Budget defined in sharepoint */ + 1 AS bla + FROM gold.data_budget AS sb + GROUP BY ALL +) +SELECT + * +FROM base""", + pretty=True, + ) def test_types(self): self.validate("INT 1", "CAST(1 AS INT)") @@ -661,19 +700,27 @@ FROM dw_1_dw_1_1.exactonline_2.transactionlines""", write="spark2", ) - @mock.patch("sqlglot.helper.logger") - def test_index_offset(self, logger): - self.validate("x[0]", "x[1]", write="presto", identity=False) - self.validate("x[1]", "x[0]", read="presto", identity=False) - logger.warning.assert_any_call("Applying array index offset (%s)", 1) - logger.warning.assert_any_call("Applying array index offset (%s)", -1) + def test_index_offset(self): + with self.assertLogs(helper_logger) as cm: + self.validate("x[0]", "x[1]", write="presto", identity=False) + self.validate("x[1]", "x[0]", read="presto", identity=False) - self.validate("x[x - 1]", "x[x - 1]", write="presto", identity=False) - self.validate( - "x[array_size(y) - 1]", "x[CARDINALITY(y) - 1 + 1]", write="presto", identity=False - ) - self.validate("x[3 - 1]", "x[3]", write="presto", identity=False) - self.validate("MAP(a, b)[0]", "MAP(a, b)[0]", write="presto", identity=False) + self.validate("x[x - 1]", "x[x - 1]", write="presto", identity=False) + self.validate( + "x[array_size(y) - 1]", "x[CARDINALITY(y) - 1 + 1]", write="presto", identity=False + ) + self.validate("x[3 - 1]", "x[3]", write="presto", identity=False) + self.validate("MAP(a, b)[0]", "MAP(a, b)[0]", write="presto", identity=False) + + self.assertEqual( + cm.output, + [ + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (-1)", + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (1)", + ], + ) def test_identify_lambda(self): self.validate("x(y -> y)", 'X("y" -> "y")', identify=True) @@ -706,6 +753,10 @@ FROM dw_1_dw_1_1.exactonline_2.transactionlines""", def test_pretty_line_breaks(self): self.assertEqual(transpile("SELECT '1\n2'", pretty=True)[0], "SELECT\n '1\n2'") + self.assertEqual( + transpile("SELECT '1\n2'", pretty=True, unsupported_level=ErrorLevel.IGNORE)[0], + "SELECT\n '1\n2'", + ) @mock.patch("sqlglot.parser.logger") def test_error_level(self, logger): |