summaryrefslogtreecommitdiffstats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r--tests/test_optimizer.py66
1 files changed, 20 insertions, 46 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 0e8ce15..c0b362c 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -230,6 +230,17 @@ class TestOptimizer(unittest.TestCase):
def test_qualify_columns(self, logger):
self.assertEqual(
optimizer.qualify_columns.qualify_columns(
+ parse_one(
+ "WITH RECURSIVE t AS (SELECT 1 AS x UNION ALL SELECT x + 1 FROM t AS child WHERE x < 10) SELECT * FROM t"
+ ),
+ schema={},
+ infer_schema=False,
+ ).sql(),
+ "WITH RECURSIVE t AS (SELECT 1 AS x UNION ALL SELECT child.x + 1 AS _col_0 FROM t AS child WHERE child.x < 10) SELECT t.x AS x FROM t",
+ )
+
+ self.assertEqual(
+ optimizer.qualify_columns.qualify_columns(
parse_one("WITH x AS (SELECT a FROM db.y) SELECT * FROM db.x"),
schema={"db": {"x": {"z": "int"}, "y": {"a": "int"}}},
expand_stars=False,
@@ -617,53 +628,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
level="warning",
)
- def test_struct_type_annotation(self):
- tests = {
- ("SELECT STRUCT(1 AS col)", "spark"): "STRUCT<col INT>",
- ("SELECT STRUCT(1 AS col, 2.5 AS row)", "spark"): "STRUCT<col INT, row DOUBLE>",
- ("SELECT STRUCT(1)", "bigquery"): "STRUCT<INT>",
- (
- "SELECT STRUCT(1 AS col, 2.5 AS row, struct(3.5 AS inner_col, 4 AS inner_row) AS nested_struct)",
- "spark",
- ): "STRUCT<col INT, row DOUBLE, nested_struct STRUCT<inner_col DOUBLE, inner_row INT>>",
- (
- "SELECT STRUCT(1 AS col, 2.5, ARRAY[1, 2, 3] AS nested_array, 'foo')",
- "bigquery",
- ): "STRUCT<col INT, DOUBLE, nested_array ARRAY<INT>, VARCHAR>",
- ("SELECT STRUCT(1, 2.5, 'bar')", "spark"): "STRUCT<INT, DOUBLE, VARCHAR>",
- ('SELECT STRUCT(1 AS "CaseSensitive")', "spark"): 'STRUCT<"CaseSensitive" INT>',
- ("SELECT STRUCT_PACK(a := 1, b := 2.5)", "duckdb"): "STRUCT<a INT, b DOUBLE>",
- ("SELECT ROW(1, 2.5, 'foo')", "presto"): "STRUCT<INT, DOUBLE, VARCHAR>",
- }
-
- for (sql, dialect), target_type in tests.items():
- with self.subTest(sql):
- expression = annotate_types(parse_one(sql, read=dialect))
- assert expression.expressions[0].is_type(target_type)
-
- def test_literal_type_annotation(self):
- tests = {
- "SELECT 5": exp.DataType.Type.INT,
- "SELECT 5.3": exp.DataType.Type.DOUBLE,
- "SELECT 'bla'": exp.DataType.Type.VARCHAR,
- "5": exp.DataType.Type.INT,
- "5.3": exp.DataType.Type.DOUBLE,
- "'bla'": exp.DataType.Type.VARCHAR,
- }
-
- for sql, target_type in tests.items():
- expression = annotate_types(parse_one(sql))
- self.assertEqual(expression.find(exp.Literal).type.this, target_type)
-
- def test_boolean_type_annotation(self):
- tests = {
- "SELECT TRUE": exp.DataType.Type.BOOLEAN,
- "FALSE": exp.DataType.Type.BOOLEAN,
- }
+ def test_annotate_types(self):
+ for i, (meta, sql, expected) in enumerate(
+ load_sql_fixture_pairs("optimizer/annotate_types.sql"), start=1
+ ):
+ title = meta.get("title") or f"{i}, {sql}"
+ dialect = meta.get("dialect")
+ result = parse_and_optimize(annotate_types, sql, dialect)
- for sql, target_type in tests.items():
- expression = annotate_types(parse_one(sql))
- self.assertEqual(expression.find(exp.Boolean).type.this, target_type)
+ with self.subTest(title):
+ self.assertEqual(result.type.sql(), exp.DataType.build(expected).sql())
def test_cast_type_annotation(self):
expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))