summary refs log tree commit diff stats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- tests/test_optimizer.py 155
1 files changed, 99 insertions, 56 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index ecf581d..0c5f6cd 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -333,7 +333,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
for sql, target_type in tests.items():
expression = annotate_types(parse_one(sql))
- self.assertEqual(expression.find(exp.Literal).type, target_type)
+ self.assertEqual(expression.find(exp.Literal).type.this, target_type)
def test_boolean_type_annotation(self):
tests = {
@@ -343,31 +343,33 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
for sql, target_type in tests.items():
expression = annotate_types(parse_one(sql))
- self.assertEqual(expression.find(exp.Boolean).type, target_type)
+ self.assertEqual(expression.find(exp.Boolean).type.this, target_type)
def test_cast_type_annotation(self):
expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))
+ self.assertEqual(expression.type.this, exp.DataType.Type.TIMESTAMPTZ)
+ self.assertEqual(expression.this.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(expression.args["to"].type.this, exp.DataType.Type.TIMESTAMPTZ)
+ self.assertEqual(expression.args["to"].expressions[0].type.this, exp.DataType.Type.INT)
- self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ)
- self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(expression.args["to"].type, exp.DataType.Type.TIMESTAMPTZ)
- self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)
+ expression = annotate_types(parse_one("ARRAY(1)::ARRAY<INT>"))
+ self.assertEqual(expression.type, parse_one("ARRAY<INT>", into=exp.DataType))
def test_cache_annotation(self):
expression = annotate_types(
parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
)
- self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)
+ self.assertEqual(expression.expression.expressions[0].type.this, exp.DataType.Type.INT)
def test_binary_annotation(self):
expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0]
- self.assertEqual(expression.type, exp.DataType.Type.DOUBLE)
- self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE)
- self.assertEqual(expression.right.type, exp.DataType.Type.INT)
- self.assertEqual(expression.right.this.type, exp.DataType.Type.INT)
- self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT)
- self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT)
+ self.assertEqual(expression.type.this, exp.DataType.Type.DOUBLE)
+ self.assertEqual(expression.left.type.this, exp.DataType.Type.DOUBLE)
+ self.assertEqual(expression.right.type.this, exp.DataType.Type.INT)
+ self.assertEqual(expression.right.this.type.this, exp.DataType.Type.INT)
+ self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT)
+ self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT)
def test_derived_tables_column_annotation(self):
schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}}
@@ -387,128 +389,169 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
"""
expression = annotate_types(parse_one(sql), schema=schema)
- self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT) # a.cola AS cola
+ self.assertEqual(
+ expression.expressions[0].type.this, exp.DataType.Type.FLOAT
+ ) # a.cola AS cola
addition_alias = expression.args["from"].expressions[0].this.expressions[0]
- self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT) # x.cola + y.cola AS cola
+ self.assertEqual(
+ addition_alias.type.this, exp.DataType.Type.FLOAT
+ ) # x.cola + y.cola AS cola
addition = addition_alias.this
- self.assertEqual(addition.type, exp.DataType.Type.FLOAT)
- self.assertEqual(addition.this.type, exp.DataType.Type.INT)
- self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT)
+ self.assertEqual(addition.type.this, exp.DataType.Type.FLOAT)
+ self.assertEqual(addition.this.type.this, exp.DataType.Type.INT)
+ self.assertEqual(addition.expression.type.this, exp.DataType.Type.FLOAT)
def test_cte_column_annotation(self):
- schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}}
+ schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT", "colc": "BOOLEAN"}}
sql = """
WITH tbl AS (
- SELECT x.cola + 'bla' AS cola, y.colb AS colb
+ SELECT x.cola + 'bla' AS cola, y.colb AS colb, y.colc AS colc
FROM (
SELECT x.cola AS cola
FROM x AS x
) AS x
JOIN (
- SELECT y.colb AS colb
+ SELECT y.colb AS colb, y.colc AS colc
FROM y AS y
) AS y
)
SELECT tbl.cola + tbl.colb + 'foo' AS col
FROM tbl AS tbl
+ WHERE tbl.colc = True
"""
expression = annotate_types(parse_one(sql), schema=schema)
self.assertEqual(
- expression.expressions[0].type, exp.DataType.Type.TEXT
+ expression.expressions[0].type.this, exp.DataType.Type.TEXT
) # tbl.cola + tbl.colb + 'foo' AS col
outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo'
- self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)
- self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT)
- self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(outer_addition.type.this, exp.DataType.Type.TEXT)
+ self.assertEqual(outer_addition.left.type.this, exp.DataType.Type.TEXT)
+ self.assertEqual(outer_addition.right.type.this, exp.DataType.Type.VARCHAR)
inner_addition = expression.expressions[0].this.left # tbl.cola + tbl.colb
- self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)
+ self.assertEqual(inner_addition.left.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(inner_addition.right.type.this, exp.DataType.Type.TEXT)
+
+ # WHERE tbl.colc = True
+ self.assertEqual(expression.args["where"].this.type.this, exp.DataType.Type.BOOLEAN)
cte_select = expression.args["with"].expressions[0].this
self.assertEqual(
- cte_select.expressions[0].type, exp.DataType.Type.VARCHAR
+ cte_select.expressions[0].type.this, exp.DataType.Type.VARCHAR
) # x.cola + 'bla' AS cola
- self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb
+ self.assertEqual(
+ cte_select.expressions[1].type.this, exp.DataType.Type.TEXT
+ ) # y.colb AS colb
+ self.assertEqual(
+ cte_select.expressions[2].type.this, exp.DataType.Type.BOOLEAN
+ ) # y.colc AS colc
cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla'
- self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR)
- self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(cte_select_addition.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(cte_select_addition.left.type.this, exp.DataType.Type.CHAR)
+ self.assertEqual(cte_select_addition.right.type.this, exp.DataType.Type.VARCHAR)
# Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
for d, t in zip(
cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
):
- self.assertEqual(d.this.expressions[0].this.type, t)
+ self.assertEqual(d.this.expressions[0].this.type.this, t)
def test_function_annotation(self):
schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}}
sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x"
concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
- self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.VARCHAR)
concat_expr = concat_expr_alias.this
- self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
- self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb)
- self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb
+ self.assertEqual(concat_expr.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR) # x.cola
+ self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.VARCHAR) # TRIM(x.colb)
+ self.assertEqual(concat_expr.right.this.type.this, exp.DataType.Type.CHAR) # x.colb
sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x"
case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
- self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(case_expr_alias.type.this, exp.DataType.Type.VARCHAR)
case_expr = case_expr_alias.this
- self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR)
+ self.assertEqual(case_expr.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(case_expr.args["default"].type.this, exp.DataType.Type.CHAR)
case_ifs_expr = case_expr.args["ifs"][0]
- self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(case_ifs_expr.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(case_ifs_expr.args["true"].type.this, exp.DataType.Type.VARCHAR)
def test_unknown_annotation(self):
schema = {"x": {"cola": "VARCHAR"}}
sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"
concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
- self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN)
+ self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.UNKNOWN)
concat_expr = concat_expr_alias.this
- self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
- self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
+ self.assertEqual(concat_expr.type.this, exp.DataType.Type.UNKNOWN)
+ self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR) # x.cola
self.assertEqual(
- concat_expr.right.type, exp.DataType.Type.UNKNOWN
+ concat_expr.right.type.this, exp.DataType.Type.UNKNOWN
) # SOME_ANONYMOUS_FUNC(x.cola)
self.assertEqual(
- concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR
+ concat_expr.right.expressions[0].type.this, exp.DataType.Type.VARCHAR
) # x.cola (arg)
def test_null_annotation(self):
expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
- self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
- self.assertEqual(expression.right.type, exp.DataType.Type.INT)
+ self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
+ self.assertEqual(expression.right.type.this, exp.DataType.Type.INT)
# NULL <op> UNKNOWN should yield NULL
sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result"
concat_expr_alias = annotate_types(parse_one(sql)).expressions[0]
- self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL)
+ self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.NULL)
concat_expr = concat_expr_alias.this
- self.assertEqual(concat_expr.type, exp.DataType.Type.NULL)
- self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL)
- self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN)
+ self.assertEqual(concat_expr.type.this, exp.DataType.Type.NULL)
+ self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.NULL)
+ self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.UNKNOWN)
def test_nullable_annotation(self):
nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
expression = annotate_types(parse_one("NULL AND FALSE"))
self.assertEqual(expression.type, nullable)
- self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
- self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN)
+ self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
+ self.assertEqual(expression.right.type.this, exp.DataType.Type.BOOLEAN)
+
+ def test_predicate_annotation(self):
+ expression = annotate_types(parse_one("x BETWEEN a AND b"))
+ self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN)
+
+ expression = annotate_types(parse_one("x IN (a, b, c, d)"))
+ self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN)
+
+ def test_aggfunc_annotation(self):
+ schema = {"x": {"cola": "SMALLINT", "colb": "FLOAT", "colc": "TEXT", "cold": "DATE"}}
+
+ tests = {
+ ("AVG", "cola"): exp.DataType.Type.DOUBLE,
+ ("SUM", "cola"): exp.DataType.Type.BIGINT,
+ ("SUM", "colb"): exp.DataType.Type.DOUBLE,
+ ("MIN", "cola"): exp.DataType.Type.SMALLINT,
+ ("MIN", "colb"): exp.DataType.Type.FLOAT,
+ ("MAX", "colc"): exp.DataType.Type.TEXT,
+ ("MAX", "cold"): exp.DataType.Type.DATE,
+ ("COUNT", "colb"): exp.DataType.Type.BIGINT,
+ ("STDDEV", "cola"): exp.DataType.Type.DOUBLE,
+ }
+
+ for (func, col), target_type in tests.items():
+ expression = annotate_types(
+ parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema
+ )
+ self.assertEqual(expression.expressions[0].type.this, target_type)