diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-10-09 06:28:52 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-10-09 06:28:52 +0000 |
commit | a8cfe41f416430cab0d6aa0ff6a688b2832a39aa (patch) | |
tree | dd48a1d853317a0daaaf3e2f6868e01dbad936e7 /tests/test_optimizer.py | |
parent | Releasing debian version 25.24.0-1. (diff) | |
download | sqlglot-a8cfe41f416430cab0d6aa0ff6a688b2832a39aa.tar.xz sqlglot-a8cfe41f416430cab0d6aa0ff6a688b2832a39aa.zip |
Merging upstream version 25.24.5.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- | tests/test_optimizer.py | 57 |
1 files changed, 34 insertions, 23 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 857ba1a..2c2015b 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -54,6 +54,18 @@ def simplify(expression, **kwargs): return optimizer.simplify.simplify(expression, constant_propagation=True, **kwargs) +def annotate_functions(expression, **kwargs): + from sqlglot.dialects import Dialect + + dialect = kwargs.get("dialect") + schema = kwargs.get("schema") + + annotators = Dialect.get_or_raise(dialect).ANNOTATORS + annotated = annotate_types(expression, annotators=annotators, schema=schema) + + return annotated.expressions[0] + + class TestOptimizer(unittest.TestCase): maxDiff = None @@ -787,6 +799,28 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') with self.subTest(title): self.assertEqual(result.type.sql(), exp.DataType.build(expected).sql()) + def test_annotate_funcs(self): + test_schema = { + "tbl": {"bin_col": "BINARY", "str_col": "STRING", "bignum_col": "BIGNUMERIC"} + } + + for i, (meta, sql, expected) in enumerate( + load_sql_fixture_pairs("optimizer/annotate_functions.sql"), start=1 + ): + title = meta.get("title") or f"{i}, {sql}" + dialect = meta.get("dialect") or "" + sql = f"SELECT {sql} FROM tbl" + + for dialect in dialect.split(", "): + result = parse_and_optimize( + annotate_functions, sql, dialect, schema=test_schema, dialect=dialect + ) + + with self.subTest(title): + self.assertEqual( + result.type.sql(dialect), exp.DataType.build(expected).sql(dialect) + ) + def test_cast_type_annotation(self): expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")) self.assertEqual(expression.type.this, exp.DataType.Type.TIMESTAMPTZ) @@ -1377,26 +1411,3 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(4, normalization_distance(gen_expr(2), max_=100)) self.assertEqual(18, normalization_distance(gen_expr(3), max_=100)) self.assertEqual(110, normalization_distance(gen_expr(10), max_=100)) - - def test_custom_annotators(self): - # In Spark hierarchy, SUBSTRING result type is dependent on input expr type - for dialect in ("spark2", "spark", "databricks"): - for expr_type_pair in ( - ("col", "STRING"), - ("col", "BINARY"), - ("'str_literal'", "STRING"), - ("CAST('str_literal' AS BINARY)", "BINARY"), - ): - with self.subTest( - f"Testing {dialect}'s SUBSTRING() result type for {expr_type_pair}" - ): - expr, type = expr_type_pair - ast = parse_one(f"SELECT substring({expr}, 2, 3) AS x FROM tbl", read=dialect) - - subst_type = ( - optimizer.optimize(ast, schema={"tbl": {"col": type}}, dialect=dialect) - .expressions[0] - .type - ) - - self.assertEqual(subst_type.sql(dialect), exp.DataType.build(type).sql(dialect)) |