summaryrefslogtreecommitdiffstats
path: root/tests/test_executor.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2022-12-02 09:16:29 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2022-12-02 09:16:29 +0000
commit1a60bbae98d3b530924a6807a55f8250de19ea86 (patch)
tree87d3000f271a6604fff43db188731229aed918a8 /tests/test_executor.py
parentAdding upstream version 10.0.8. (diff)
downloadsqlglot-1a60bbae98d3b530924a6807a55f8250de19ea86.tar.xz
sqlglot-1a60bbae98d3b530924a6807a55f8250de19ea86.zip
Adding upstream version 10.1.3.upstream/10.1.3
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/test_executor.py')
-rw-r--r--tests/test_executor.py71
1 files changed, 69 insertions, 2 deletions
diff --git a/tests/test_executor.py b/tests/test_executor.py
index 2c4d7cd..9d452e4 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -68,13 +68,13 @@ class TestExecutor(unittest.TestCase):
def test_execute_tpch(self):
def to_csv(expression):
- if isinstance(expression, exp.Table):
+ if isinstance(expression, exp.Table) and expression.name not in ("revenue"):
return parse_one(
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
)
return expression
- for i, (sql, _) in enumerate(self.sqls[0:7]):
+ for i, (sql, _) in enumerate(self.sqls[0:16]):
with self.subTest(f"tpch-h {i + 1}"):
a = self.cached_execute(sql)
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
@@ -165,6 +165,39 @@ class TestExecutor(unittest.TestCase):
["a"],
[("a",)],
),
+ (
+ "SELECT DISTINCT a FROM (SELECT 1 AS a UNION ALL SELECT 1 AS a)",
+ ["a"],
+ [(1,)],
+ ),
+ (
+ "SELECT DISTINCT a, SUM(b) AS b "
+ "FROM (SELECT 'a' AS a, 1 AS b UNION ALL SELECT 'a' AS a, 2 AS b UNION ALL SELECT 'b' AS a, 1 AS b) "
+ "GROUP BY a "
+ "LIMIT 1",
+ ["a", "b"],
+ [("a", 3)],
+ ),
+ (
+ "SELECT COUNT(1) AS a FROM (SELECT 1)",
+ ["a"],
+ [(1,)],
+ ),
+ (
+ "SELECT COUNT(1) AS a FROM (SELECT 1) LIMIT 0",
+ ["a"],
+ [],
+ ),
+ (
+ "SELECT a FROM x GROUP BY a LIMIT 0",
+ ["a"],
+ [],
+ ),
+ (
+ "SELECT a FROM x LIMIT 0",
+ ["a"],
+ [],
+ ),
]:
with self.subTest(sql):
result = execute(sql, schema=schema, tables=tables)
@@ -346,6 +379,28 @@ class TestExecutor(unittest.TestCase):
],
)
+ def test_execute_subqueries(self):
+ tables = {
+ "table": [
+ {"a": 1, "b": 1},
+ {"a": 2, "b": 2},
+ ],
+ }
+
+ self.assertEqual(
+ execute(
+ """
+ SELECT *
+ FROM table
+ WHERE a = (SELECT MAX(a) FROM table)
+ """,
+ tables=tables,
+ ).rows,
+ [
+ (2, 2),
+ ],
+ )
+
def test_table_depth_mismatch(self):
tables = {"table": []}
schema = {"db": {"table": {"col": "VARCHAR"}}}
@@ -401,6 +456,7 @@ class TestExecutor(unittest.TestCase):
("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
+ ("SELECT SUM(x) FROM (SELECT 1 AS x WHERE FALSE)", ["_col_0"], [(0,)]),
]:
result = execute(sql)
self.assertEqual(result.columns, tuple(cols))
@@ -462,7 +518,18 @@ class TestExecutor(unittest.TestCase):
("IF(false, 1, 0)", 0),
("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"),
("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)),
+ ("1 IN (1, 2, 3)", True),
+ ("1 IN (2, 3)", False),
+ ("NULL IS NULL", True),
+ ("NULL IS NOT NULL", False),
+ ("NULL = NULL", None),
+ ("NULL <> NULL", None),
]:
with self.subTest(sql):
result = execute(f"SELECT {sql}")
self.assertEqual(result.rows, [(expected,)])
+
+ def test_case_sensitivity(self):
+ result = execute("SELECT A AS A FROM X", tables={"x": [{"a": 1}]})
+ self.assertEqual(result.columns, ("A",))
+ self.assertEqual(result.rows, [(1,)])