From e894f65cf8a2e3c88439e1b06d8542b969e2bc3f Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Wed, 11 Sep 2024 14:13:38 +0200
Subject: Merging upstream version 25.20.1.

Signed-off-by: Daniel Baumann
---
 tests/test_optimizer.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index d6e11a9..fe5a4d7 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -1345,3 +1345,26 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         self.assertEqual(4, normalization_distance(gen_expr(2), max_=100))
         self.assertEqual(18, normalization_distance(gen_expr(3), max_=100))
         self.assertEqual(110, normalization_distance(gen_expr(10), max_=100))
+
+    def test_custom_annotators(self):
+        # In Spark hierarchy, SUBSTRING result type is dependent on input expr type
+        for dialect in ("spark2", "spark", "databricks"):
+            for expr_type_pair in (
+                ("col", "STRING"),
+                ("col", "BINARY"),
+                ("'str_literal'", "STRING"),
+                ("CAST('str_literal' AS BINARY)", "BINARY"),
+            ):
+                with self.subTest(
+                    f"Testing {dialect}'s SUBSTRING() result type for {expr_type_pair}"
+                ):
+                    expr, type = expr_type_pair
+                    ast = parse_one(f"SELECT substring({expr}, 2, 3) AS x FROM tbl", read=dialect)
+
+                    subst_type = (
+                        optimizer.optimize(ast, schema={"tbl": {"col": type}}, dialect=dialect)
+                        .expressions[0]
+                        .type
+                    )
+
+                    self.assertEqual(subst_type.sql(dialect), exp.DataType.build(type).sql(dialect))
-- 
cgit v1.2.3
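
The following is a minimal standalone sketch (not part of the patch) of the behaviour the new test exercises: when the optimizer annotates types for Spark-family dialects, the result type of SUBSTRING follows the type of its input expression. It uses only the public sqlglot calls that already appear in the test (parse_one, optimizer.optimize, exp.DataType.build); the table and column names are arbitrary.

    # Sketch only: assumes a sqlglot version that includes this change (>= 25.20.1).
    from sqlglot import exp, optimizer, parse_one

    for col_type in ("STRING", "BINARY"):
        # Parse a Spark query that applies SUBSTRING to a column of known type.
        ast = parse_one("SELECT SUBSTRING(col, 2, 3) AS x FROM tbl", read="spark")

        # Optimizing with a schema annotates every expression with a data type.
        optimized = optimizer.optimize(ast, schema={"tbl": {"col": col_type}}, dialect="spark")
        result_type = optimized.expressions[0].type

        # The annotated result type mirrors the input type: STRING -> STRING, BINARY -> BINARY.
        assert result_type.sql("spark") == exp.DataType.build(col_type).sql("spark")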