diff options
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- | tests/test_optimizer.py | 54 |
1 files changed, 44 insertions, 10 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 3b5990f..a1b7e70 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -67,7 +67,9 @@ class TestOptimizer(unittest.TestCase): } def check_file(self, file, func, pretty=False, execute=False, **kwargs): - for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1): + for i, (meta, sql, expected) in enumerate( + load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1 + ): title = meta.get("title") or f"{i}, {sql}" dialect = meta.get("dialect") leave_tables_isolated = meta.get("leave_tables_isolated") @@ -90,7 +92,9 @@ class TestOptimizer(unittest.TestCase): if string_to_bool(should_execute): with self.subTest(f"(execute) {title}"): - df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df() + df1 = self.conn.execute( + sqlglot.transpile(sql, read=dialect, write="duckdb")[0] + ).df() df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df() assert_frame_equal(df1, df2) @@ -268,7 +272,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)") self.assertEqual( - scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)" + scopes[3].expression.sql(), + "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", ) self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y") self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") @@ -287,7 +292,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') # Check that we can walk in scope from an arbitrary node self.assertEqual( - {node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)}, + { + node.sql() + for node, *_ in walk_in_scope(expression.find(exp.Where)) + if isinstance(node, exp.Column) + }, {"s.b"}, ) @@ -324,7 +333,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT) def test_cache_annotation(self): - expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")) + expression = annotate_types( + parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1") + ) self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT) def test_binary_annotation(self): @@ -384,7 +395,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') """ expression = annotate_types(parse_one(sql), schema=schema) - self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT) # tbl.cola + tbl.colb + 'foo' AS col + self.assertEqual( + expression.expressions[0].type, exp.DataType.Type.TEXT + ) # tbl.cola + tbl.colb + 'foo' AS col outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo' self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT) @@ -396,7 +409,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT) cte_select = expression.args["with"].expressions[0].this - self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola + 'bla' AS cola + self.assertEqual( + cte_select.expressions[0].type, exp.DataType.Type.VARCHAR + ) # x.cola + 'bla' AS cola self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla' @@ -405,7 +420,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR) # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively - for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]): + for d, t in zip( + cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT] + ): self.assertEqual(d.this.expressions[0].this.type, t) def test_function_annotation(self): @@ -421,6 +438,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb) self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb + sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x" + + case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] + self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR) + + case_expr = case_expr_alias.this + self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR) + self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR) + + case_ifs_expr = case_expr.args["ifs"][0] + self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR) + self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR) + def test_unknown_annotation(self): schema = {"x": {"cola": "VARCHAR"}} sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x" @@ -431,8 +461,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') concat_expr = concat_expr_alias.this self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN) self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola - self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) # SOME_ANONYMOUS_FUNC(x.cola) - self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola (arg) + self.assertEqual( + concat_expr.right.type, exp.DataType.Type.UNKNOWN + ) # SOME_ANONYMOUS_FUNC(x.cola) + self.assertEqual( + concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR + ) # x.cola (arg) def test_null_annotation(self): expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this |