diff options
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- | tests/test_optimizer.py | 155 |
1 files changed, 99 insertions, 56 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index ecf581d..0c5f6cd 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -333,7 +333,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') for sql, target_type in tests.items(): expression = annotate_types(parse_one(sql)) - self.assertEqual(expression.find(exp.Literal).type, target_type) + self.assertEqual(expression.find(exp.Literal).type.this, target_type) def test_boolean_type_annotation(self): tests = { @@ -343,31 +343,33 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') for sql, target_type in tests.items(): expression = annotate_types(parse_one(sql)) - self.assertEqual(expression.find(exp.Boolean).type, target_type) + self.assertEqual(expression.find(exp.Boolean).type.this, target_type) def test_cast_type_annotation(self): expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")) + self.assertEqual(expression.type.this, exp.DataType.Type.TIMESTAMPTZ) + self.assertEqual(expression.this.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(expression.args["to"].type.this, exp.DataType.Type.TIMESTAMPTZ) + self.assertEqual(expression.args["to"].expressions[0].type.this, exp.DataType.Type.INT) - self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ) - self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR) - self.assertEqual(expression.args["to"].type, exp.DataType.Type.TIMESTAMPTZ) - self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT) + expression = annotate_types(parse_one("ARRAY(1)::ARRAY<INT>")) + self.assertEqual(expression.type, parse_one("ARRAY<INT>", into=exp.DataType)) def test_cache_annotation(self): expression = annotate_types( parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1") ) - self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT) + self.assertEqual(expression.expression.expressions[0].type.this, exp.DataType.Type.INT) def test_binary_annotation(self): expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0] - self.assertEqual(expression.type, exp.DataType.Type.DOUBLE) - self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE) - self.assertEqual(expression.right.type, exp.DataType.Type.INT) - self.assertEqual(expression.right.this.type, exp.DataType.Type.INT) - self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT) - self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT) + self.assertEqual(expression.type.this, exp.DataType.Type.DOUBLE) + self.assertEqual(expression.left.type.this, exp.DataType.Type.DOUBLE) + self.assertEqual(expression.right.type.this, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.type.this, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT) def test_derived_tables_column_annotation(self): schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}} @@ -387,128 +389,169 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') """ expression = annotate_types(parse_one(sql), schema=schema) - self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT) # a.cola AS cola + self.assertEqual( + expression.expressions[0].type.this, exp.DataType.Type.FLOAT + ) # a.cola AS cola addition_alias = expression.args["from"].expressions[0].this.expressions[0] - self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT) # x.cola + y.cola AS cola + self.assertEqual( + addition_alias.type.this, exp.DataType.Type.FLOAT + ) # x.cola + y.cola AS cola addition = addition_alias.this - self.assertEqual(addition.type, exp.DataType.Type.FLOAT) - self.assertEqual(addition.this.type, exp.DataType.Type.INT) - self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT) + self.assertEqual(addition.type.this, exp.DataType.Type.FLOAT) + self.assertEqual(addition.this.type.this, exp.DataType.Type.INT) + self.assertEqual(addition.expression.type.this, exp.DataType.Type.FLOAT) def test_cte_column_annotation(self): - schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}} + schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT", "colc": "BOOLEAN"}} sql = """ WITH tbl AS ( - SELECT x.cola + 'bla' AS cola, y.colb AS colb + SELECT x.cola + 'bla' AS cola, y.colb AS colb, y.colc AS colc FROM ( SELECT x.cola AS cola FROM x AS x ) AS x JOIN ( - SELECT y.colb AS colb + SELECT y.colb AS colb, y.colc AS colc FROM y AS y ) AS y ) SELECT tbl.cola + tbl.colb + 'foo' AS col FROM tbl AS tbl + WHERE tbl.colc = True """ expression = annotate_types(parse_one(sql), schema=schema) self.assertEqual( - expression.expressions[0].type, exp.DataType.Type.TEXT + expression.expressions[0].type.this, exp.DataType.Type.TEXT ) # tbl.cola + tbl.colb + 'foo' AS col outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo' - self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT) - self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT) - self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR) + self.assertEqual(outer_addition.type.this, exp.DataType.Type.TEXT) + self.assertEqual(outer_addition.left.type.this, exp.DataType.Type.TEXT) + self.assertEqual(outer_addition.right.type.this, exp.DataType.Type.VARCHAR) inner_addition = expression.expressions[0].this.left # tbl.cola + tbl.colb - self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR) - self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT) + self.assertEqual(inner_addition.left.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(inner_addition.right.type.this, exp.DataType.Type.TEXT) + + # WHERE tbl.colc = True + self.assertEqual(expression.args["where"].this.type.this, exp.DataType.Type.BOOLEAN) cte_select = expression.args["with"].expressions[0].this self.assertEqual( - cte_select.expressions[0].type, exp.DataType.Type.VARCHAR + cte_select.expressions[0].type.this, exp.DataType.Type.VARCHAR ) # x.cola + 'bla' AS cola - self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb + self.assertEqual( + cte_select.expressions[1].type.this, exp.DataType.Type.TEXT + ) # y.colb AS colb + self.assertEqual( + cte_select.expressions[2].type.this, exp.DataType.Type.BOOLEAN + ) # y.colc AS colc cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla' - self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR) - self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR) - self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR) + self.assertEqual(cte_select_addition.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(cte_select_addition.left.type.this, exp.DataType.Type.CHAR) + self.assertEqual(cte_select_addition.right.type.this, exp.DataType.Type.VARCHAR) # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively for d, t in zip( cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT] ): - self.assertEqual(d.this.expressions[0].this.type, t) + self.assertEqual(d.this.expressions[0].this.type.this, t) def test_function_annotation(self): schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}} sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x" concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] - self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR) + self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.VARCHAR) concat_expr = concat_expr_alias.this - self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR) - self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola - self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb) - self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb + self.assertEqual(concat_expr.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR) # x.cola + self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.VARCHAR) # TRIM(x.colb) + self.assertEqual(concat_expr.right.this.type.this, exp.DataType.Type.CHAR) # x.colb sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x" case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] - self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR) + self.assertEqual(case_expr_alias.type.this, exp.DataType.Type.VARCHAR) case_expr = case_expr_alias.this - self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR) - self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR) + self.assertEqual(case_expr.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(case_expr.args["default"].type.this, exp.DataType.Type.CHAR) case_ifs_expr = case_expr.args["ifs"][0] - self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR) - self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR) + self.assertEqual(case_ifs_expr.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(case_ifs_expr.args["true"].type.this, exp.DataType.Type.VARCHAR) def test_unknown_annotation(self): schema = {"x": {"cola": "VARCHAR"}} sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x" concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] - self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN) + self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.UNKNOWN) concat_expr = concat_expr_alias.this - self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN) - self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola + self.assertEqual(concat_expr.type.this, exp.DataType.Type.UNKNOWN) + self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR) # x.cola self.assertEqual( - concat_expr.right.type, exp.DataType.Type.UNKNOWN + concat_expr.right.type.this, exp.DataType.Type.UNKNOWN ) # SOME_ANONYMOUS_FUNC(x.cola) self.assertEqual( - concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR + concat_expr.right.expressions[0].type.this, exp.DataType.Type.VARCHAR ) # x.cola (arg) def test_null_annotation(self): expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this - self.assertEqual(expression.left.type, exp.DataType.Type.NULL) - self.assertEqual(expression.right.type, exp.DataType.Type.INT) + self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL) + self.assertEqual(expression.right.type.this, exp.DataType.Type.INT) # NULL <op> UNKNOWN should yield NULL sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result" concat_expr_alias = annotate_types(parse_one(sql)).expressions[0] - self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL) + self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.NULL) concat_expr = concat_expr_alias.this - self.assertEqual(concat_expr.type, exp.DataType.Type.NULL) - self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL) - self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) + self.assertEqual(concat_expr.type.this, exp.DataType.Type.NULL) + self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.NULL) + self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.UNKNOWN) def test_nullable_annotation(self): nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")) expression = annotate_types(parse_one("NULL AND FALSE")) self.assertEqual(expression.type, nullable) - self.assertEqual(expression.left.type, exp.DataType.Type.NULL) - self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN) + self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL) + self.assertEqual(expression.right.type.this, exp.DataType.Type.BOOLEAN) + + def test_predicate_annotation(self): + expression = annotate_types(parse_one("x BETWEEN a AND b")) + self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN) + + expression = annotate_types(parse_one("x IN (a, b, c, d)")) + self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN) + + def test_aggfunc_annotation(self): + schema = {"x": {"cola": "SMALLINT", "colb": "FLOAT", "colc": "TEXT", "cold": "DATE"}} + + tests = { + ("AVG", "cola"): exp.DataType.Type.DOUBLE, + ("SUM", "cola"): exp.DataType.Type.BIGINT, + ("SUM", "colb"): exp.DataType.Type.DOUBLE, + ("MIN", "cola"): exp.DataType.Type.SMALLINT, + ("MIN", "colb"): exp.DataType.Type.FLOAT, + ("MAX", "colc"): exp.DataType.Type.TEXT, + ("MAX", "cold"): exp.DataType.Type.DATE, + ("COUNT", "colb"): exp.DataType.Type.BIGINT, + ("STDDEV", "cola"): exp.DataType.Type.DOUBLE, + } + + for (func, col), target_type in tests.items(): + expression = annotate_types( + parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema + ) + self.assertEqual(expression.expressions[0].type.this, target_type) |