summaryrefslogtreecommitdiffstats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-10-09 06:28:52 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-10-09 06:28:52 +0000
commita8cfe41f416430cab0d6aa0ff6a688b2832a39aa (patch)
treedd48a1d853317a0daaaf3e2f6868e01dbad936e7 /tests/test_optimizer.py
parentReleasing debian version 25.24.0-1. (diff)
downloadsqlglot-a8cfe41f416430cab0d6aa0ff6a688b2832a39aa.tar.xz
sqlglot-a8cfe41f416430cab0d6aa0ff6a688b2832a39aa.zip
Merging upstream version 25.24.5.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r--tests/test_optimizer.py57
1 files changed, 34 insertions, 23 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 857ba1a..2c2015b 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -54,6 +54,18 @@ def simplify(expression, **kwargs):
return optimizer.simplify.simplify(expression, constant_propagation=True, **kwargs)
+def annotate_functions(expression, **kwargs):
+ from sqlglot.dialects import Dialect
+
+ dialect = kwargs.get("dialect")
+ schema = kwargs.get("schema")
+
+ annotators = Dialect.get_or_raise(dialect).ANNOTATORS
+ annotated = annotate_types(expression, annotators=annotators, schema=schema)
+
+ return annotated.expressions[0]
+
+
class TestOptimizer(unittest.TestCase):
maxDiff = None
@@ -787,6 +799,28 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
with self.subTest(title):
self.assertEqual(result.type.sql(), exp.DataType.build(expected).sql())
+ def test_annotate_funcs(self):
+ test_schema = {
+ "tbl": {"bin_col": "BINARY", "str_col": "STRING", "bignum_col": "BIGNUMERIC"}
+ }
+
+ for i, (meta, sql, expected) in enumerate(
+ load_sql_fixture_pairs("optimizer/annotate_functions.sql"), start=1
+ ):
+ title = meta.get("title") or f"{i}, {sql}"
+ dialect = meta.get("dialect") or ""
+ sql = f"SELECT {sql} FROM tbl"
+
+ for dialect in dialect.split(", "):
+ result = parse_and_optimize(
+ annotate_functions, sql, dialect, schema=test_schema, dialect=dialect
+ )
+
+ with self.subTest(title):
+ self.assertEqual(
+ result.type.sql(dialect), exp.DataType.build(expected).sql(dialect)
+ )
+
def test_cast_type_annotation(self):
expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))
self.assertEqual(expression.type.this, exp.DataType.Type.TIMESTAMPTZ)
@@ -1377,26 +1411,3 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(4, normalization_distance(gen_expr(2), max_=100))
self.assertEqual(18, normalization_distance(gen_expr(3), max_=100))
self.assertEqual(110, normalization_distance(gen_expr(10), max_=100))
-
- def test_custom_annotators(self):
- # In Spark hierarchy, SUBSTRING result type is dependent on input expr type
- for dialect in ("spark2", "spark", "databricks"):
- for expr_type_pair in (
- ("col", "STRING"),
- ("col", "BINARY"),
- ("'str_literal'", "STRING"),
- ("CAST('str_literal' AS BINARY)", "BINARY"),
- ):
- with self.subTest(
- f"Testing {dialect}'s SUBSTRING() result type for {expr_type_pair}"
- ):
- expr, type = expr_type_pair
- ast = parse_one(f"SELECT substring({expr}, 2, 3) AS x FROM tbl", read=dialect)
-
- subst_type = (
- optimizer.optimize(ast, schema={"tbl": {"col": type}}, dialect=dialect)
- .expressions[0]
- .type
- )
-
- self.assertEqual(subst_type.sql(dialect), exp.DataType.build(type).sql(dialect))