diff options
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- | tests/test_optimizer.py | 82 |
1 files changed, 64 insertions, 18 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index fd95577..141203d 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -103,6 +103,10 @@ class TestOptimizer(unittest.TestCase): "d": "TEXT", "e": "TEXT", }, + "temporal": { + "d": "DATE", + "t": "DATETIME", + }, } def check_file(self, file, func, pretty=False, execute=False, set_dialect=False, **kwargs): @@ -179,6 +183,18 @@ class TestOptimizer(unittest.TestCase): ) def test_qualify_tables(self): + self.assertEqual( + optimizer.qualify_tables.qualify_tables( + parse_one("select a from b"), catalog="catalog" + ).sql(), + "SELECT a FROM b AS b", + ) + + self.assertEqual( + optimizer.qualify_tables.qualify_tables(parse_one("select a from b"), db='"DB"').sql(), + 'SELECT a FROM "DB".b AS b', + ) + self.check_file( "qualify_tables", optimizer.qualify_tables.qualify_tables, @@ -282,6 +298,13 @@ class TestOptimizer(unittest.TestCase): self.assertEqual(optimizer.normalize_identifiers.normalize_identifiers("a%").sql(), '"a%"') + def test_quote_identifiers(self): + self.check_file( + "quote_identifiers", + optimizer.qualify_columns.quote_identifiers, + set_dialect=True, + ) + def test_pushdown_projection(self): self.check_file("pushdown_projections", pushdown_projections, schema=self.schema) @@ -300,8 +323,8 @@ class TestOptimizer(unittest.TestCase): safe_concat = parse_one("CONCAT('a', x, 'b', 'c')") simplified_safe_concat = optimizer.simplify.simplify(safe_concat) - self.assertIs(type(simplified_concat), exp.Concat) - self.assertIs(type(simplified_safe_concat), exp.SafeConcat) + self.assertEqual(simplified_concat.args["safe"], False) + self.assertEqual(simplified_safe_concat.args["safe"], True) self.assertEqual("CONCAT('a', x, 'bc')", simplified_concat.sql(dialect="presto")) self.assertEqual("CONCAT('a', x, 'bc')", simplified_safe_concat.sql()) @@ -561,6 +584,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT) self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT) + for numeric_type in ("BIGINT", "DOUBLE", "INT"): + query = f"SELECT '1' + CAST(x AS {numeric_type})" + expression = annotate_types(parse_one(query)).expressions[0] + self.assertEqual(expression.type, exp.DataType.build(numeric_type)) + + def test_typeddiv_annotation(self): + expressions = annotate_types( + parse_one("SELECT 2 / 3, 2 / 3.0", dialect="presto") + ).expressions + + self.assertEqual(expressions[0].type.this, exp.DataType.Type.BIGINT) + self.assertEqual(expressions[1].type.this, exp.DataType.Type.DOUBLE) + def test_bracket_annotation(self): expression = annotate_types(parse_one("SELECT A[:]")).expressions[0] @@ -609,45 +645,60 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') "b": "DATETIME", } } - for sql, expected_type, *expected_sql in [ + for sql, expected_type in [ ( "SELECT '2023-01-01' + INTERVAL '1' DAY", exp.DataType.Type.DATE, - "SELECT CAST('2023-01-01' AS DATE) + INTERVAL '1' DAY", ), ( "SELECT '2023-01-01' + INTERVAL '1' HOUR", exp.DataType.Type.DATETIME, - "SELECT CAST('2023-01-01' AS DATETIME) + INTERVAL '1' HOUR", ), ( "SELECT '2023-01-01 00:00:01' + INTERVAL '1' HOUR", exp.DataType.Type.DATETIME, - "SELECT CAST('2023-01-01 00:00:01' AS DATETIME) + INTERVAL '1' HOUR", ), ("SELECT 'nonsense' + INTERVAL '1' DAY", exp.DataType.Type.UNKNOWN), ("SELECT x.a + INTERVAL '1' DAY FROM x AS x", exp.DataType.Type.DATE), - ("SELECT x.a + INTERVAL '1' HOUR FROM x AS x", exp.DataType.Type.DATETIME), + ( + "SELECT x.a + INTERVAL '1' HOUR FROM x AS x", + exp.DataType.Type.DATETIME, + ), ("SELECT x.b + INTERVAL '1' DAY FROM x AS x", exp.DataType.Type.DATETIME), ("SELECT x.b + INTERVAL '1' HOUR FROM x AS x", exp.DataType.Type.DATETIME), ( "SELECT DATE_ADD('2023-01-01', 1, 'DAY')", exp.DataType.Type.DATE, - "SELECT DATE_ADD(CAST('2023-01-01' AS DATE), 1, 'DAY')", ), ( "SELECT DATE_ADD('2023-01-01 00:00:00', 1, 'DAY')", exp.DataType.Type.DATETIME, - "SELECT DATE_ADD(CAST('2023-01-01 00:00:00' AS DATETIME), 1, 'DAY')", ), ("SELECT DATE_ADD(x.a, 1, 'DAY') FROM x AS x", exp.DataType.Type.DATE), - ("SELECT DATE_ADD(x.a, 1, 'HOUR') FROM x AS x", exp.DataType.Type.DATETIME), + ( + "SELECT DATE_ADD(x.a, 1, 'HOUR') FROM x AS x", + exp.DataType.Type.DATETIME, + ), ("SELECT DATE_ADD(x.b, 1, 'DAY') FROM x AS x", exp.DataType.Type.DATETIME), + ("SELECT DATE_TRUNC('DAY', x.a) FROM x AS x", exp.DataType.Type.DATE), + ("SELECT DATE_TRUNC('DAY', x.b) FROM x AS x", exp.DataType.Type.DATETIME), + ( + "SELECT DATE_TRUNC('SECOND', x.a) FROM x AS x", + exp.DataType.Type.DATETIME, + ), + ( + "SELECT DATE_TRUNC('DAY', '2023-01-01') FROM x AS x", + exp.DataType.Type.DATE, + ), + ( + "SELECT DATEDIFF('2023-01-01', '2023-01-02', DAY) FROM x AS x", + exp.DataType.Type.INT, + ), ]: with self.subTest(sql): expression = annotate_types(parse_one(sql), schema=schema) self.assertEqual(expected_type, expression.expressions[0].type.this) - self.assertEqual(expected_sql[0] if expected_sql else sql, expression.sql()) + self.assertEqual(sql, expression.sql()) def test_lateral_annotation(self): expression = optimizer.optimize( @@ -843,6 +894,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') ("MAX", "cold"): exp.DataType.Type.DATE, ("COUNT", "colb"): exp.DataType.Type.BIGINT, ("STDDEV", "cola"): exp.DataType.Type.DOUBLE, + ("ABS", "cola"): exp.DataType.Type.SMALLINT, + ("ABS", "colb"): exp.DataType.Type.FLOAT, } for (func, col), target_type in tests.items(): @@ -989,10 +1042,3 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') query = parse_one("select a.b:c from d", read="snowflake") qualified = optimizer.qualify.qualify(query) self.assertEqual(qualified.expressions[0].alias, "c") - - def test_qualify_tables_no_schema(self): - query = parse_one("select a from b") - self.assertEqual( - optimizer.qualify_tables.qualify_tables(query, catalog="catalog").sql(), - "SELECT a FROM b AS b", - ) |