diff options
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- | tests/test_optimizer.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index d6e11a9..fe5a4d7 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1345,3 +1345,26 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(4, normalization_distance(gen_expr(2), max_=100)) self.assertEqual(18, normalization_distance(gen_expr(3), max_=100)) self.assertEqual(110, normalization_distance(gen_expr(10), max_=100)) + + def test_custom_annotators(self): + # In Spark hierarchy, SUBSTRING result type is dependent on input expr type + for dialect in ("spark2", "spark", "databricks"): + for expr_type_pair in ( + ("col", "STRING"), + ("col", "BINARY"), + ("'str_literal'", "STRING"), + ("CAST('str_literal' AS BINARY)", "BINARY"), + ): + with self.subTest( + f"Testing {dialect}'s SUBSTRING() result type for {expr_type_pair}" + ): + expr, type = expr_type_pair + ast = parse_one(f"SELECT substring({expr}, 2, 3) AS x FROM tbl", read=dialect) + + subst_type = ( + optimizer.optimize(ast, schema={"tbl": {"col": type}}, dialect=dialect) + .expressions[0] + .type + ) + + self.assertEqual(subst_type.sql(dialect), exp.DataType.build(type).sql(dialect)) |