summary | refs | log | tree | commit | diff | stats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- tests/test_optimizer.py 23
1 files changed, 23 insertions, 0 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index d6e11a9..fe5a4d7 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -1345,3 +1345,26 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(4, normalization_distance(gen_expr(2), max_=100))
self.assertEqual(18, normalization_distance(gen_expr(3), max_=100))
self.assertEqual(110, normalization_distance(gen_expr(10), max_=100))
+
+    def test_custom_annotators(self):
+        """SUBSTRING() in the Spark dialect family takes the type of its input expression."""
+        # Each case is (operand SQL, type declared for `col` and expected for the result).
+        cases = (
+            ("col", "STRING"),
+            ("col", "BINARY"),
+            ("'str_literal'", "STRING"),
+            ("CAST('str_literal' AS BINARY)", "BINARY"),
+        )
+
+        for dialect in ("spark2", "spark", "databricks"):
+            for case in cases:
+                label = f"Testing {dialect}'s SUBSTRING() result type for {case}"
+                with self.subTest(label):
+                    operand, type_name = case
+                    query = f"SELECT substring({operand}, 2, 3) AS x FROM tbl"
+                    tree = parse_one(query, read=dialect)
+                    schema = {"tbl": {"col": type_name}}
+                    annotated = optimizer.optimize(tree, schema=schema, dialect=dialect)
+                    actual_sql = annotated.expressions[0].type.sql(dialect)
+                    expected_sql = exp.DataType.build(type_name).sql(dialect)
+                    self.assertEqual(actual_sql, expected_sql)