diff options
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- | tests/test_optimizer.py | 66 |
1 files changed, 20 insertions, 46 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 0e8ce15..c0b362c 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -230,6 +230,17 @@ class TestOptimizer(unittest.TestCase): def test_qualify_columns(self, logger): self.assertEqual( optimizer.qualify_columns.qualify_columns( + parse_one( + "WITH RECURSIVE t AS (SELECT 1 AS x UNION ALL SELECT x + 1 FROM t AS child WHERE x < 10) SELECT * FROM t" + ), + schema={}, + infer_schema=False, + ).sql(), + "WITH RECURSIVE t AS (SELECT 1 AS x UNION ALL SELECT child.x + 1 AS _col_0 FROM t AS child WHERE child.x < 10) SELECT t.x AS x FROM t", + ) + + self.assertEqual( + optimizer.qualify_columns.qualify_columns( parse_one("WITH x AS (SELECT a FROM db.y) SELECT * FROM db.x"), schema={"db": {"x": {"z": "int"}, "y": {"a": "int"}}}, expand_stars=False, @@ -617,53 +628,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') level="warning", ) - def test_struct_type_annotation(self): - tests = { - ("SELECT STRUCT(1 AS col)", "spark"): "STRUCT<col INT>", - ("SELECT STRUCT(1 AS col, 2.5 AS row)", "spark"): "STRUCT<col INT, row DOUBLE>", - ("SELECT STRUCT(1)", "bigquery"): "STRUCT<INT>", - ( - "SELECT STRUCT(1 AS col, 2.5 AS row, struct(3.5 AS inner_col, 4 AS inner_row) AS nested_struct)", - "spark", - ): "STRUCT<col INT, row DOUBLE, nested_struct STRUCT<inner_col DOUBLE, inner_row INT>>", - ( - "SELECT STRUCT(1 AS col, 2.5, ARRAY[1, 2, 3] AS nested_array, 'foo')", - "bigquery", - ): "STRUCT<col INT, DOUBLE, nested_array ARRAY<INT>, VARCHAR>", - ("SELECT STRUCT(1, 2.5, 'bar')", "spark"): "STRUCT<INT, DOUBLE, VARCHAR>", - ('SELECT STRUCT(1 AS "CaseSensitive")', "spark"): 'STRUCT<"CaseSensitive" INT>', - ("SELECT STRUCT_PACK(a := 1, b := 2.5)", "duckdb"): "STRUCT<a INT, b DOUBLE>", - ("SELECT ROW(1, 2.5, 'foo')", "presto"): "STRUCT<INT, DOUBLE, VARCHAR>", - } - - for (sql, dialect), target_type in tests.items(): - with self.subTest(sql): - expression = annotate_types(parse_one(sql, read=dialect)) - assert expression.expressions[0].is_type(target_type) - - def test_literal_type_annotation(self): - tests = { - "SELECT 5": exp.DataType.Type.INT, - "SELECT 5.3": exp.DataType.Type.DOUBLE, - "SELECT 'bla'": exp.DataType.Type.VARCHAR, - "5": exp.DataType.Type.INT, - "5.3": exp.DataType.Type.DOUBLE, - "'bla'": exp.DataType.Type.VARCHAR, - } - - for sql, target_type in tests.items(): - expression = annotate_types(parse_one(sql)) - self.assertEqual(expression.find(exp.Literal).type.this, target_type) - - def test_boolean_type_annotation(self): - tests = { - "SELECT TRUE": exp.DataType.Type.BOOLEAN, - "FALSE": exp.DataType.Type.BOOLEAN, - } + def test_annotate_types(self): + for i, (meta, sql, expected) in enumerate( + load_sql_fixture_pairs("optimizer/annotate_types.sql"), start=1 + ): + title = meta.get("title") or f"{i}, {sql}" + dialect = meta.get("dialect") + result = parse_and_optimize(annotate_types, sql, dialect) - for sql, target_type in tests.items(): - expression = annotate_types(parse_one(sql)) - self.assertEqual(expression.find(exp.Boolean).type.this, target_type) + with self.subTest(title): + self.assertEqual(result.type.sql(), exp.DataType.build(expected).sql()) def test_cast_type_annotation(self): expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")) |