summaryrefslogtreecommitdiffstats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r--tests/test_optimizer.py82
1 files changed, 64 insertions, 18 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index fd95577..141203d 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -103,6 +103,10 @@ class TestOptimizer(unittest.TestCase):
"d": "TEXT",
"e": "TEXT",
},
+ "temporal": {
+ "d": "DATE",
+ "t": "DATETIME",
+ },
}
def check_file(self, file, func, pretty=False, execute=False, set_dialect=False, **kwargs):
@@ -179,6 +183,18 @@ class TestOptimizer(unittest.TestCase):
)
def test_qualify_tables(self):
+ self.assertEqual(
+ optimizer.qualify_tables.qualify_tables(
+ parse_one("select a from b"), catalog="catalog"
+ ).sql(),
+ "SELECT a FROM b AS b",
+ )
+
+ self.assertEqual(
+ optimizer.qualify_tables.qualify_tables(parse_one("select a from b"), db='"DB"').sql(),
+ 'SELECT a FROM "DB".b AS b',
+ )
+
self.check_file(
"qualify_tables",
optimizer.qualify_tables.qualify_tables,
@@ -282,6 +298,13 @@ class TestOptimizer(unittest.TestCase):
self.assertEqual(optimizer.normalize_identifiers.normalize_identifiers("a%").sql(), '"a%"')
+ def test_quote_identifiers(self):
+ self.check_file(
+ "quote_identifiers",
+ optimizer.qualify_columns.quote_identifiers,
+ set_dialect=True,
+ )
+
def test_pushdown_projection(self):
self.check_file("pushdown_projections", pushdown_projections, schema=self.schema)
@@ -300,8 +323,8 @@ class TestOptimizer(unittest.TestCase):
safe_concat = parse_one("CONCAT('a', x, 'b', 'c')")
simplified_safe_concat = optimizer.simplify.simplify(safe_concat)
- self.assertIs(type(simplified_concat), exp.Concat)
- self.assertIs(type(simplified_safe_concat), exp.SafeConcat)
+ self.assertEqual(simplified_concat.args["safe"], False)
+ self.assertEqual(simplified_safe_concat.args["safe"], True)
self.assertEqual("CONCAT('a', x, 'bc')", simplified_concat.sql(dialect="presto"))
self.assertEqual("CONCAT('a', x, 'bc')", simplified_safe_concat.sql())
@@ -561,6 +584,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT)
self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT)
+ for numeric_type in ("BIGINT", "DOUBLE", "INT"):
+ query = f"SELECT '1' + CAST(x AS {numeric_type})"
+ expression = annotate_types(parse_one(query)).expressions[0]
+ self.assertEqual(expression.type, exp.DataType.build(numeric_type))
+
+ def test_typeddiv_annotation(self):
+ expressions = annotate_types(
+ parse_one("SELECT 2 / 3, 2 / 3.0", dialect="presto")
+ ).expressions
+
+ self.assertEqual(expressions[0].type.this, exp.DataType.Type.BIGINT)
+ self.assertEqual(expressions[1].type.this, exp.DataType.Type.DOUBLE)
+
def test_bracket_annotation(self):
expression = annotate_types(parse_one("SELECT A[:]")).expressions[0]
@@ -609,45 +645,60 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
"b": "DATETIME",
}
}
- for sql, expected_type, *expected_sql in [
+ for sql, expected_type in [
(
"SELECT '2023-01-01' + INTERVAL '1' DAY",
exp.DataType.Type.DATE,
- "SELECT CAST('2023-01-01' AS DATE) + INTERVAL '1' DAY",
),
(
"SELECT '2023-01-01' + INTERVAL '1' HOUR",
exp.DataType.Type.DATETIME,
- "SELECT CAST('2023-01-01' AS DATETIME) + INTERVAL '1' HOUR",
),
(
"SELECT '2023-01-01 00:00:01' + INTERVAL '1' HOUR",
exp.DataType.Type.DATETIME,
- "SELECT CAST('2023-01-01 00:00:01' AS DATETIME) + INTERVAL '1' HOUR",
),
("SELECT 'nonsense' + INTERVAL '1' DAY", exp.DataType.Type.UNKNOWN),
("SELECT x.a + INTERVAL '1' DAY FROM x AS x", exp.DataType.Type.DATE),
- ("SELECT x.a + INTERVAL '1' HOUR FROM x AS x", exp.DataType.Type.DATETIME),
+ (
+ "SELECT x.a + INTERVAL '1' HOUR FROM x AS x",
+ exp.DataType.Type.DATETIME,
+ ),
("SELECT x.b + INTERVAL '1' DAY FROM x AS x", exp.DataType.Type.DATETIME),
("SELECT x.b + INTERVAL '1' HOUR FROM x AS x", exp.DataType.Type.DATETIME),
(
"SELECT DATE_ADD('2023-01-01', 1, 'DAY')",
exp.DataType.Type.DATE,
- "SELECT DATE_ADD(CAST('2023-01-01' AS DATE), 1, 'DAY')",
),
(
"SELECT DATE_ADD('2023-01-01 00:00:00', 1, 'DAY')",
exp.DataType.Type.DATETIME,
- "SELECT DATE_ADD(CAST('2023-01-01 00:00:00' AS DATETIME), 1, 'DAY')",
),
("SELECT DATE_ADD(x.a, 1, 'DAY') FROM x AS x", exp.DataType.Type.DATE),
- ("SELECT DATE_ADD(x.a, 1, 'HOUR') FROM x AS x", exp.DataType.Type.DATETIME),
+ (
+ "SELECT DATE_ADD(x.a, 1, 'HOUR') FROM x AS x",
+ exp.DataType.Type.DATETIME,
+ ),
("SELECT DATE_ADD(x.b, 1, 'DAY') FROM x AS x", exp.DataType.Type.DATETIME),
+ ("SELECT DATE_TRUNC('DAY', x.a) FROM x AS x", exp.DataType.Type.DATE),
+ ("SELECT DATE_TRUNC('DAY', x.b) FROM x AS x", exp.DataType.Type.DATETIME),
+ (
+ "SELECT DATE_TRUNC('SECOND', x.a) FROM x AS x",
+ exp.DataType.Type.DATETIME,
+ ),
+ (
+ "SELECT DATE_TRUNC('DAY', '2023-01-01') FROM x AS x",
+ exp.DataType.Type.DATE,
+ ),
+ (
+ "SELECT DATEDIFF('2023-01-01', '2023-01-02', DAY) FROM x AS x",
+ exp.DataType.Type.INT,
+ ),
]:
with self.subTest(sql):
expression = annotate_types(parse_one(sql), schema=schema)
self.assertEqual(expected_type, expression.expressions[0].type.this)
- self.assertEqual(expected_sql[0] if expected_sql else sql, expression.sql())
+ self.assertEqual(sql, expression.sql())
def test_lateral_annotation(self):
expression = optimizer.optimize(
@@ -843,6 +894,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
("MAX", "cold"): exp.DataType.Type.DATE,
("COUNT", "colb"): exp.DataType.Type.BIGINT,
("STDDEV", "cola"): exp.DataType.Type.DOUBLE,
+ ("ABS", "cola"): exp.DataType.Type.SMALLINT,
+ ("ABS", "colb"): exp.DataType.Type.FLOAT,
}
for (func, col), target_type in tests.items():
@@ -989,10 +1042,3 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
query = parse_one("select a.b:c from d", read="snowflake")
qualified = optimizer.qualify.qualify(query)
self.assertEqual(qualified.expressions[0].alias, "c")
-
- def test_qualify_tables_no_schema(self):
- query = parse_one("select a from b")
- self.assertEqual(
- optimizer.qualify_tables.qualify_tables(query, catalog="catalog").sql(),
- "SELECT a FROM b AS b",
- )