summaryrefslogtreecommitdiffstats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2022-10-10 11:29:00 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2022-10-10 11:29:00 +0000
commit74b38d30f43f7005428e09fa80508c5f21324c99 (patch)
tree7a0d4e49fffdc0330fc941c6528d3c8669a2acc6 /tests/test_optimizer.py
parentAdding upstream version 6.2.8. (diff)
downloadsqlglot-74b38d30f43f7005428e09fa80508c5f21324c99.tar.xz
sqlglot-74b38d30f43f7005428e09fa80508c5f21324c99.zip
Adding upstream version 6.3.1.upstream/6.3.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r--tests/test_optimizer.py225
1 files changed, 198 insertions, 27 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index aad84ed..36a7785 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -1,17 +1,55 @@
import unittest
from functools import partial
+import duckdb
+from pandas.testing import assert_frame_equal
+
+import sqlglot
from sqlglot import exp, optimizer, parse_one, table
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
-from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures
+from tests.helpers import (
+ TPCH_SCHEMA,
+ load_sql_fixture_pairs,
+ load_sql_fixtures,
+ string_to_bool,
+)
class TestOptimizer(unittest.TestCase):
maxDiff = None
+ @classmethod
+ def setUpClass(cls):
+ cls.conn = duckdb.connect()
+ cls.conn.execute(
+ """
+ CREATE TABLE x (a INT, b INT);
+ CREATE TABLE y (b INT, c INT);
+ CREATE TABLE z (b INT, c INT);
+
+ INSERT INTO x VALUES (1, 1);
+ INSERT INTO x VALUES (2, 2);
+ INSERT INTO x VALUES (2, 2);
+ INSERT INTO x VALUES (3, 3);
+ INSERT INTO x VALUES (null, null);
+
+ INSERT INTO y VALUES (2, 2);
+ INSERT INTO y VALUES (2, 2);
+ INSERT INTO y VALUES (3, 3);
+ INSERT INTO y VALUES (4, 4);
+ INSERT INTO y VALUES (null, null);
+
+ INSERT INTO y VALUES (3, 3);
+ INSERT INTO y VALUES (3, 3);
+ INSERT INTO y VALUES (4, 4);
+ INSERT INTO y VALUES (5, 5);
+ INSERT INTO y VALUES (null, null);
+ """
+ )
+
def setUp(self):
self.schema = {
"x": {
@@ -28,29 +66,42 @@ class TestOptimizer(unittest.TestCase):
},
}
- def check_file(self, file, func, pretty=False, **kwargs):
+ def check_file(self, file, func, pretty=False, execute=False, **kwargs):
for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1):
+ title = meta.get("title") or f"{i}, {sql}"
dialect = meta.get("dialect")
leave_tables_isolated = meta.get("leave_tables_isolated")
func_kwargs = {**kwargs}
if leave_tables_isolated is not None:
- func_kwargs["leave_tables_isolated"] = leave_tables_isolated.lower() in ("true", "1")
+ func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)
+
+ optimized = func(parse_one(sql, read=dialect), **func_kwargs)
- with self.subTest(f"{i}, {sql}"):
+ with self.subTest(title):
self.assertEqual(
- func(parse_one(sql, read=dialect), **func_kwargs).sql(pretty=pretty, dialect=dialect),
+ optimized.sql(pretty=pretty, dialect=dialect),
expected,
)
+ should_execute = meta.get("execute")
+ if should_execute is None:
+ should_execute = execute
+
+ if string_to_bool(should_execute):
+ with self.subTest(f"(execute) {title}"):
+ df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df()
+ df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df()
+ assert_frame_equal(df1, df2)
+
def test_optimize(self):
schema = {
"x": {"a": "INT", "b": "INT"},
- "y": {"a": "INT", "b": "INT"},
+ "y": {"b": "INT", "c": "INT"},
"z": {"a": "INT", "c": "INT"},
}
- self.check_file("optimizer", optimizer.optimize, pretty=True, schema=schema)
+ self.check_file("optimizer", optimizer.optimize, pretty=True, execute=True, schema=schema)
def test_isolate_table_selects(self):
self.check_file(
@@ -86,7 +137,16 @@ class TestOptimizer(unittest.TestCase):
expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
return expression
- self.check_file("qualify_columns", qualify_columns, schema=self.schema)
+ self.check_file("qualify_columns", qualify_columns, execute=True, schema=self.schema)
+
+ def test_qualify_columns__with_invisible(self):
+ def qualify_columns(expression, **kwargs):
+ expression = optimizer.qualify_tables.qualify_tables(expression)
+ expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
+ return expression
+
+ schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}})
+ self.check_file("qualify_columns__with_invisible", qualify_columns, schema=schema)
def test_qualify_columns__invalid(self):
for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
@@ -141,7 +201,7 @@ class TestOptimizer(unittest.TestCase):
],
)
- self.check_file("merge_subqueries", optimize, schema=self.schema)
+ self.check_file("merge_subqueries", optimize, execute=True, schema=self.schema)
def test_eliminate_subqueries(self):
self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries)
@@ -301,10 +361,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
}
for sql, target_type in tests.items():
- expression = parse_one(sql)
- annotated_expression = annotate_types(expression)
-
- self.assertEqual(annotated_expression.find(exp.Literal).type, target_type)
+ expression = annotate_types(parse_one(sql))
+ self.assertEqual(expression.find(exp.Literal).type, target_type)
def test_boolean_type_annotation(self):
tests = {
@@ -313,14 +371,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
}
for sql, target_type in tests.items():
- expression = parse_one(sql)
- annotated_expression = annotate_types(expression)
-
- self.assertEqual(annotated_expression.find(exp.Boolean).type, target_type)
+ expression = annotate_types(parse_one(sql))
+ self.assertEqual(expression.find(exp.Boolean).type, target_type)
def test_cast_type_annotation(self):
- expression = parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")
- annotate_types(expression)
+ expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))
self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ)
self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR)
@@ -328,16 +383,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)
def test_cache_annotation(self):
- expression = parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
- annotated_expression = annotate_types(expression)
-
- self.assertEqual(annotated_expression.expression.expressions[0].type, exp.DataType.Type.INT)
+ expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1"))
+ self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)
def test_binary_annotation(self):
- expression = parse_one("SELECT 0.0 + (2 + 3)")
- annotate_types(expression)
-
- expression = expression.expressions[0]
+ expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0]
self.assertEqual(expression.type, exp.DataType.Type.DOUBLE)
self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE)
@@ -345,3 +395,124 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.right.this.type, exp.DataType.Type.INT)
self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT)
self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT)
+
+ def test_derived_tables_column_annotation(self):
+ schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}}
+ sql = """
+ SELECT a.cola AS cola
+ FROM (
+ SELECT x.cola + y.cola AS cola
+ FROM (
+ SELECT x.cola AS cola
+ FROM x AS x
+ ) AS x
+ JOIN (
+ SELECT y.cola AS cola
+ FROM y AS y
+ ) AS y
+ ) AS a
+ """
+
+ expression = annotate_types(parse_one(sql), schema=schema)
+ self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT) # a.cola AS cola
+
+ addition_alias = expression.args["from"].expressions[0].this.expressions[0]
+ self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT) # x.cola + y.cola AS cola
+
+ addition = addition_alias.this
+ self.assertEqual(addition.type, exp.DataType.Type.FLOAT)
+ self.assertEqual(addition.this.type, exp.DataType.Type.INT)
+ self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT)
+
+ def test_cte_column_annotation(self):
+ schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}}
+ sql = """
+ WITH tbl AS (
+ SELECT x.cola + 'bla' AS cola, y.colb AS colb
+ FROM (
+ SELECT x.cola AS cola
+ FROM x AS x
+ ) AS x
+ JOIN (
+ SELECT y.colb AS colb
+ FROM y AS y
+ ) AS y
+ )
+ SELECT tbl.cola + tbl.colb + 'foo' AS col
+ FROM tbl AS tbl
+ """
+
+ expression = annotate_types(parse_one(sql), schema=schema)
+ self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT) # tbl.cola + tbl.colb + 'foo' AS col
+
+ outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo'
+ self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)
+ self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT)
+ self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR)
+
+ inner_addition = expression.expressions[0].this.left # tbl.cola + tbl.colb
+ self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)
+
+ cte_select = expression.args["with"].expressions[0].this
+ self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola + 'bla' AS cola
+ self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb
+
+ cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla'
+ self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR)
+ self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)
+
+ # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
+ for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]):
+ self.assertEqual(d.this.expressions[0].this.type, t)
+
+ def test_function_annotation(self):
+ schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}}
+ sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x"
+
+ concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
+ self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR)
+
+ concat_expr = concat_expr_alias.this
+ self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
+ self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb)
+ self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb
+
+ def test_unknown_annotation(self):
+ schema = {"x": {"cola": "VARCHAR"}}
+ sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"
+
+ concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
+ self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN)
+
+ concat_expr = concat_expr_alias.this
+ self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
+ self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
+ self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) # SOME_ANONYMOUS_FUNC(x.cola)
+ self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola (arg)
+
+ def test_null_annotation(self):
+ expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
+ self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
+ self.assertEqual(expression.right.type, exp.DataType.Type.INT)
+
+ # NULL <op> UNKNOWN should yield NULL
+ sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result"
+
+ concat_expr_alias = annotate_types(parse_one(sql)).expressions[0]
+ self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL)
+
+ concat_expr = concat_expr_alias.this
+ self.assertEqual(concat_expr.type, exp.DataType.Type.NULL)
+ self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL)
+ self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN)
+
+ def test_nullable_annotation(self):
+ nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
+ expression = annotate_types(parse_one("NULL AND FALSE"))
+
+ self.assertEqual(expression.type, nullable)
+ self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
+ self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN)