diff options
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- | tests/test_optimizer.py | 225 |
1 files changed, 198 insertions, 27 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index aad84ed..36a7785 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1,17 +1,55 @@ import unittest from functools import partial +import duckdb +from pandas.testing import assert_frame_equal + +import sqlglot from sqlglot import exp, optimizer, parse_one, table from sqlglot.errors import OptimizeError from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.schema import MappingSchema, ensure_schema from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope -from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures +from tests.helpers import ( + TPCH_SCHEMA, + load_sql_fixture_pairs, + load_sql_fixtures, + string_to_bool, +) class TestOptimizer(unittest.TestCase): maxDiff = None + @classmethod + def setUpClass(cls): + cls.conn = duckdb.connect() + cls.conn.execute( + """ + CREATE TABLE x (a INT, b INT); + CREATE TABLE y (b INT, c INT); + CREATE TABLE z (b INT, c INT); + + INSERT INTO x VALUES (1, 1); + INSERT INTO x VALUES (2, 2); + INSERT INTO x VALUES (2, 2); + INSERT INTO x VALUES (3, 3); + INSERT INTO x VALUES (null, null); + + INSERT INTO y VALUES (2, 2); + INSERT INTO y VALUES (2, 2); + INSERT INTO y VALUES (3, 3); + INSERT INTO y VALUES (4, 4); + INSERT INTO y VALUES (null, null); + + INSERT INTO y VALUES (3, 3); + INSERT INTO y VALUES (3, 3); + INSERT INTO y VALUES (4, 4); + INSERT INTO y VALUES (5, 5); + INSERT INTO y VALUES (null, null); + """ + ) + def setUp(self): self.schema = { "x": { @@ -28,29 +66,42 @@ class TestOptimizer(unittest.TestCase): }, } - def check_file(self, file, func, pretty=False, **kwargs): + def check_file(self, file, func, pretty=False, execute=False, **kwargs): for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1): + title = meta.get("title") or f"{i}, {sql}" dialect = meta.get("dialect") leave_tables_isolated = meta.get("leave_tables_isolated") func_kwargs = {**kwargs} if leave_tables_isolated is not None: - func_kwargs["leave_tables_isolated"] = leave_tables_isolated.lower() in ("true", "1") + func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated) + + optimized = func(parse_one(sql, read=dialect), **func_kwargs) - with self.subTest(f"{i}, {sql}"): + with self.subTest(title): self.assertEqual( - func(parse_one(sql, read=dialect), **func_kwargs).sql(pretty=pretty, dialect=dialect), + optimized.sql(pretty=pretty, dialect=dialect), expected, ) + should_execute = meta.get("execute") + if should_execute is None: + should_execute = execute + + if string_to_bool(should_execute): + with self.subTest(f"(execute) {title}"): + df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df() + df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df() + assert_frame_equal(df1, df2) + def test_optimize(self): schema = { "x": {"a": "INT", "b": "INT"}, - "y": {"a": "INT", "b": "INT"}, + "y": {"b": "INT", "c": "INT"}, "z": {"a": "INT", "c": "INT"}, } - self.check_file("optimizer", optimizer.optimize, pretty=True, schema=schema) + self.check_file("optimizer", optimizer.optimize, pretty=True, execute=True, schema=schema) def test_isolate_table_selects(self): self.check_file( @@ -86,7 +137,16 @@ class TestOptimizer(unittest.TestCase): expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) return expression - self.check_file("qualify_columns", qualify_columns, schema=self.schema) + self.check_file("qualify_columns", qualify_columns, execute=True, schema=self.schema) + + def test_qualify_columns__with_invisible(self): + def qualify_columns(expression, **kwargs): + expression = optimizer.qualify_tables.qualify_tables(expression) + expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) + return expression + + schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}}) + self.check_file("qualify_columns__with_invisible", qualify_columns, schema=schema) def test_qualify_columns__invalid(self): for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"): @@ -141,7 +201,7 @@ class TestOptimizer(unittest.TestCase): ], ) - self.check_file("merge_subqueries", optimize, schema=self.schema) + self.check_file("merge_subqueries", optimize, execute=True, schema=self.schema) def test_eliminate_subqueries(self): self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries) @@ -301,10 +361,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') } for sql, target_type in tests.items(): - expression = parse_one(sql) - annotated_expression = annotate_types(expression) - - self.assertEqual(annotated_expression.find(exp.Literal).type, target_type) + expression = annotate_types(parse_one(sql)) + self.assertEqual(expression.find(exp.Literal).type, target_type) def test_boolean_type_annotation(self): tests = { @@ -313,14 +371,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') } for sql, target_type in tests.items(): - expression = parse_one(sql) - annotated_expression = annotate_types(expression) - - self.assertEqual(annotated_expression.find(exp.Boolean).type, target_type) + expression = annotate_types(parse_one(sql)) + self.assertEqual(expression.find(exp.Boolean).type, target_type) def test_cast_type_annotation(self): - expression = parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))") - annotate_types(expression) + expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")) self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ) self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR) @@ -328,16 +383,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT) def test_cache_annotation(self): - expression = parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1") - annotated_expression = annotate_types(expression) - - self.assertEqual(annotated_expression.expression.expressions[0].type, exp.DataType.Type.INT) + expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")) + self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT) def test_binary_annotation(self): - expression = parse_one("SELECT 0.0 + (2 + 3)") - annotate_types(expression) - - expression = expression.expressions[0] + expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0] self.assertEqual(expression.type, exp.DataType.Type.DOUBLE) self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE) @@ -345,3 +395,124 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.right.this.type, exp.DataType.Type.INT) self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT) self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT) + + def test_derived_tables_column_annotation(self): + schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}} + sql = """ + SELECT a.cola AS cola + FROM ( + SELECT x.cola + y.cola AS cola + FROM ( + SELECT x.cola AS cola + FROM x AS x + ) AS x + JOIN ( + SELECT y.cola AS cola + FROM y AS y + ) AS y + ) AS a + """ + + expression = annotate_types(parse_one(sql), schema=schema) + self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT) # a.cola AS cola + + addition_alias = expression.args["from"].expressions[0].this.expressions[0] + self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT) # x.cola + y.cola AS cola + + addition = addition_alias.this + self.assertEqual(addition.type, exp.DataType.Type.FLOAT) + self.assertEqual(addition.this.type, exp.DataType.Type.INT) + self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT) + + def test_cte_column_annotation(self): + schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}} + sql = """ + WITH tbl AS ( + SELECT x.cola + 'bla' AS cola, y.colb AS colb + FROM ( + SELECT x.cola AS cola + FROM x AS x + ) AS x + JOIN ( + SELECT y.colb AS colb + FROM y AS y + ) AS y + ) + SELECT tbl.cola + tbl.colb + 'foo' AS col + FROM tbl AS tbl + """ + + expression = annotate_types(parse_one(sql), schema=schema) + self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT) # tbl.cola + tbl.colb + 'foo' AS col + + outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo' + self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT) + self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT) + self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR) + + inner_addition = expression.expressions[0].this.left # tbl.cola + tbl.colb + self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR) + self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT) + + cte_select = expression.args["with"].expressions[0].this + self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola + 'bla' AS cola + self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb + + cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla' + self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR) + self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR) + self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR) + + # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively + for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]): + self.assertEqual(d.this.expressions[0].this.type, t) + + def test_function_annotation(self): + schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}} + sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x" + + concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] + self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR) + + concat_expr = concat_expr_alias.this + self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR) + self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola + self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb) + self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb + + def test_unknown_annotation(self): + schema = {"x": {"cola": "VARCHAR"}} + sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x" + + concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] + self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN) + + concat_expr = concat_expr_alias.this + self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN) + self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola + self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) # SOME_ANONYMOUS_FUNC(x.cola) + self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola (arg) + + def test_null_annotation(self): + expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this + self.assertEqual(expression.left.type, exp.DataType.Type.NULL) + self.assertEqual(expression.right.type, exp.DataType.Type.INT) + + # NULL <op> UNKNOWN should yield NULL + sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result" + + concat_expr_alias = annotate_types(parse_one(sql)).expressions[0] + self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL) + + concat_expr = concat_expr_alias.this + self.assertEqual(concat_expr.type, exp.DataType.Type.NULL) + self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL) + self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) + + def test_nullable_annotation(self): + nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")) + expression = annotate_types(parse_one("NULL AND FALSE")) + + self.assertEqual(expression.type, nullable) + self.assertEqual(expression.left.type, exp.DataType.Type.NULL) + self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN) |