diff options
Diffstat (limited to 'tests/test_executor.py')
-rw-r--r-- | tests/test_executor.py | 71 |
1 files changed, 69 insertions, 2 deletions
diff --git a/tests/test_executor.py b/tests/test_executor.py index 2c4d7cd..9d452e4 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -68,13 +68,13 @@ class TestExecutor(unittest.TestCase): def test_execute_tpch(self): def to_csv(expression): - if isinstance(expression, exp.Table): + if isinstance(expression, exp.Table) and expression.name not in ("revenue"): return parse_one( f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}" ) return expression - for i, (sql, _) in enumerate(self.sqls[0:7]): + for i, (sql, _) in enumerate(self.sqls[0:16]): with self.subTest(f"tpch-h {i + 1}"): a = self.cached_execute(sql) sql = parse_one(sql).transform(to_csv).sql(pretty=True) @@ -165,6 +165,39 @@ class TestExecutor(unittest.TestCase): ["a"], [("a",)], ), + ( + "SELECT DISTINCT a FROM (SELECT 1 AS a UNION ALL SELECT 1 AS a)", + ["a"], + [(1,)], + ), + ( + "SELECT DISTINCT a, SUM(b) AS b " + "FROM (SELECT 'a' AS a, 1 AS b UNION ALL SELECT 'a' AS a, 2 AS b UNION ALL SELECT 'b' AS a, 1 AS b) " + "GROUP BY a " + "LIMIT 1", + ["a", "b"], + [("a", 3)], + ), + ( + "SELECT COUNT(1) AS a FROM (SELECT 1)", + ["a"], + [(1,)], + ), + ( + "SELECT COUNT(1) AS a FROM (SELECT 1) LIMIT 0", + ["a"], + [], + ), + ( + "SELECT a FROM x GROUP BY a LIMIT 0", + ["a"], + [], + ), + ( + "SELECT a FROM x LIMIT 0", + ["a"], + [], + ), ]: with self.subTest(sql): result = execute(sql, schema=schema, tables=tables) @@ -346,6 +379,28 @@ class TestExecutor(unittest.TestCase): ], ) + def test_execute_subqueries(self): + tables = { + "table": [ + {"a": 1, "b": 1}, + {"a": 2, "b": 2}, + ], + } + + self.assertEqual( + execute( + """ + SELECT * + FROM table + WHERE a = (SELECT MAX(a) FROM table) + """, + tables=tables, + ).rows, + [ + (2, 2), + ], + ) + def test_table_depth_mismatch(self): tables = {"table": []} schema = {"db": {"table": {"col": "VARCHAR"}}} @@ -401,6 +456,7 @@ class TestExecutor(unittest.TestCase): ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]), ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]), ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]), + ("SELECT SUM(x) FROM (SELECT 1 AS x WHERE FALSE)", ["_col_0"], [(0,)]), ]: result = execute(sql) self.assertEqual(result.columns, tuple(cols)) @@ -462,7 +518,18 @@ class TestExecutor(unittest.TestCase): ("IF(false, 1, 0)", 0), ("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"), ("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)), + ("1 IN (1, 2, 3)", True), + ("1 IN (2, 3)", False), + ("NULL IS NULL", True), + ("NULL IS NOT NULL", False), + ("NULL = NULL", None), + ("NULL <> NULL", None), ]: with self.subTest(sql): result = execute(f"SELECT {sql}") self.assertEqual(result.rows, [(expected,)]) + + def test_case_sensitivity(self): + result = execute("SELECT A AS A FROM X", tables={"x": [{"a": 1}]}) + self.assertEqual(result.columns, ("A",)) + self.assertEqual(result.rows, [(1,)]) |