summaryrefslogtreecommitdiffstats
path: root/tests/test_executor.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/test_executor.py403
1 files changed, 396 insertions, 7 deletions
diff --git a/tests/test_executor.py b/tests/test_executor.py
index 49805b9..2c4d7cd 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -1,12 +1,15 @@
import unittest
+from datetime import date
import duckdb
import pandas as pd
from pandas.testing import assert_frame_equal
from sqlglot import exp, parse_one
+from sqlglot.errors import ExecuteError
from sqlglot.executor import execute
from sqlglot.executor.python import Python
+from sqlglot.executor.table import Table, ensure_tables
from tests.helpers import (
FIXTURES_DIR,
SKIP_INTEGRATION,
@@ -67,13 +70,399 @@ class TestExecutor(unittest.TestCase):
def to_csv(expression):
if isinstance(expression, exp.Table):
return parse_one(
- f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
+ f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
)
return expression
- for sql, _ in self.sqls[0:3]:
- a = self.cached_execute(sql)
- sql = parse_one(sql).transform(to_csv).sql(pretty=True)
- table = execute(sql, TPCH_SCHEMA)
- b = pd.DataFrame(table.rows, columns=table.columns)
- assert_frame_equal(a, b, check_dtype=False)
+ for i, (sql, _) in enumerate(self.sqls[0:7]):
+ with self.subTest(f"tpch-h {i + 1}"):
+ a = self.cached_execute(sql)
+ sql = parse_one(sql).transform(to_csv).sql(pretty=True)
+ table = execute(sql, TPCH_SCHEMA)
+ b = pd.DataFrame(table.rows, columns=table.columns)
+ assert_frame_equal(a, b, check_dtype=False)
+
+ def test_execute_callable(self):
+ tables = {
+ "x": [
+ {"a": "a", "b": "d"},
+ {"a": "b", "b": "e"},
+ {"a": "c", "b": "f"},
+ ],
+ "y": [
+ {"b": "d", "c": "g"},
+ {"b": "e", "c": "h"},
+ {"b": "f", "c": "i"},
+ ],
+ "z": [],
+ }
+ schema = {
+ "x": {
+ "a": "VARCHAR",
+ "b": "VARCHAR",
+ },
+ "y": {
+ "b": "VARCHAR",
+ "c": "VARCHAR",
+ },
+ "z": {"d": "VARCHAR"},
+ }
+
+ for sql, cols, rows in [
+ ("SELECT * FROM x", ["a", "b"], [("a", "d"), ("b", "e"), ("c", "f")]),
+ (
+ "SELECT * FROM x JOIN y ON x.b = y.b",
+ ["a", "b", "b", "c"],
+ [("a", "d", "d", "g"), ("b", "e", "e", "h"), ("c", "f", "f", "i")],
+ ),
+ (
+ "SELECT j.c AS d FROM x AS i JOIN y AS j ON i.b = j.b",
+ ["d"],
+ [("g",), ("h",), ("i",)],
+ ),
+ (
+ "SELECT CONCAT(x.a, y.c) FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
+ ["_col_0"],
+ [("bh",)],
+ ),
+ (
+ "SELECT * FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
+ ["a", "b", "b", "c"],
+ [("b", "e", "e", "h")],
+ ),
+ (
+ "SELECT * FROM z",
+ ["d"],
+ [],
+ ),
+ (
+ "SELECT d FROM z ORDER BY d",
+ ["d"],
+ [],
+ ),
+ (
+ "SELECT a FROM x WHERE x.a <> 'b'",
+ ["a"],
+ [("a",), ("c",)],
+ ),
+ (
+ "SELECT a AS i FROM x ORDER BY a",
+ ["i"],
+ [("a",), ("b",), ("c",)],
+ ),
+ (
+ "SELECT a AS i FROM x ORDER BY i",
+ ["i"],
+ [("a",), ("b",), ("c",)],
+ ),
+ (
+ "SELECT 100 - ORD(a) AS a, a AS i FROM x ORDER BY a",
+ ["a", "i"],
+ [(1, "c"), (2, "b"), (3, "a")],
+ ),
+ (
+ "SELECT a /* test */ FROM x LIMIT 1",
+ ["a"],
+ [("a",)],
+ ),
+ ]:
+ with self.subTest(sql):
+ result = execute(sql, schema=schema, tables=tables)
+ self.assertEqual(result.columns, tuple(cols))
+ self.assertEqual(result.rows, rows)
+
+ def test_set_operations(self):
+ tables = {
+ "x": [
+ {"a": "a"},
+ {"a": "b"},
+ {"a": "c"},
+ ],
+ "y": [
+ {"a": "b"},
+ {"a": "c"},
+ {"a": "d"},
+ ],
+ }
+ schema = {
+ "x": {
+ "a": "VARCHAR",
+ },
+ "y": {
+ "a": "VARCHAR",
+ },
+ }
+
+ for sql, cols, rows in [
+ (
+ "SELECT a FROM x UNION ALL SELECT a FROM y",
+ ["a"],
+ [("a",), ("b",), ("c",), ("b",), ("c",), ("d",)],
+ ),
+ (
+ "SELECT a FROM x UNION SELECT a FROM y",
+ ["a"],
+ [("a",), ("b",), ("c",), ("d",)],
+ ),
+ (
+ "SELECT a FROM x EXCEPT SELECT a FROM y",
+ ["a"],
+ [("a",)],
+ ),
+ (
+ "SELECT a FROM x INTERSECT SELECT a FROM y",
+ ["a"],
+ [("b",), ("c",)],
+ ),
+ (
+ """SELECT i.a
+ FROM (
+ SELECT a FROM x UNION SELECT a FROM y
+ ) AS i
+ JOIN (
+ SELECT a FROM x UNION SELECT a FROM y
+ ) AS j
+ ON i.a = j.a""",
+ ["a"],
+ [("a",), ("b",), ("c",), ("d",)],
+ ),
+ (
+ "SELECT 1 AS a UNION SELECT 2 AS a UNION SELECT 3 AS a",
+ ["a"],
+ [(1,), (2,), (3,)],
+ ),
+ ]:
+ with self.subTest(sql):
+ result = execute(sql, schema=schema, tables=tables)
+ self.assertEqual(result.columns, tuple(cols))
+ self.assertEqual(set(result.rows), set(rows))
+
+ def test_execute_catalog_db_table(self):
+ tables = {
+ "catalog": {
+ "db": {
+ "x": [
+ {"a": "a"},
+ {"a": "b"},
+ {"a": "c"},
+ ],
+ }
+ }
+ }
+ schema = {
+ "catalog": {
+ "db": {
+ "x": {
+ "a": "VARCHAR",
+ }
+ }
+ }
+ }
+ result1 = execute("SELECT * FROM x", schema=schema, tables=tables)
+ result2 = execute("SELECT * FROM catalog.db.x", schema=schema, tables=tables)
+ assert result1.columns == result2.columns
+ assert result1.rows == result2.rows
+
+ def test_execute_tables(self):
+ tables = {
+ "sushi": [
+ {"id": 1, "price": 1.0},
+ {"id": 2, "price": 2.0},
+ {"id": 3, "price": 3.0},
+ ],
+ "order_items": [
+ {"sushi_id": 1, "order_id": 1},
+ {"sushi_id": 1, "order_id": 1},
+ {"sushi_id": 2, "order_id": 1},
+ {"sushi_id": 3, "order_id": 2},
+ ],
+ "orders": [
+ {"id": 1, "user_id": 1},
+ {"id": 2, "user_id": 2},
+ ],
+ }
+
+ self.assertEqual(
+ execute(
+ """
+ SELECT
+ o.user_id,
+ SUM(s.price) AS price
+ FROM orders o
+ JOIN order_items i
+ ON o.id = i.order_id
+ JOIN sushi s
+ ON i.sushi_id = s.id
+ GROUP BY o.user_id
+ """,
+ tables=tables,
+ ).rows,
+ [
+ (1, 4.0),
+ (2, 3.0),
+ ],
+ )
+
+ self.assertEqual(
+ execute(
+ """
+ SELECT
+ o.id, x.*
+ FROM orders o
+ LEFT JOIN (
+ SELECT
+ 1 AS id, 'b' AS x
+ UNION ALL
+ SELECT
+ 3 AS id, 'c' AS x
+ ) x
+ ON o.id = x.id
+ """,
+ tables=tables,
+ ).rows,
+ [(1, 1, "b"), (2, None, None)],
+ )
+ self.assertEqual(
+ execute(
+ """
+ SELECT
+ o.id, x.*
+ FROM orders o
+ RIGHT JOIN (
+ SELECT
+ 1 AS id,
+ 'b' AS x
+ UNION ALL
+ SELECT
+ 3 AS id, 'c' AS x
+ ) x
+ ON o.id = x.id
+ """,
+ tables=tables,
+ ).rows,
+ [
+ (1, 1, "b"),
+ (None, 3, "c"),
+ ],
+ )
+
+ def test_table_depth_mismatch(self):
+ tables = {"table": []}
+ schema = {"db": {"table": {"col": "VARCHAR"}}}
+ with self.assertRaises(ExecuteError):
+ execute("SELECT * FROM table", schema=schema, tables=tables)
+
+ def test_tables(self):
+ tables = ensure_tables(
+ {
+ "catalog1": {
+ "db1": {
+ "t1": [
+ {"a": 1},
+ ],
+ "t2": [
+ {"a": 1},
+ ],
+ },
+ "db2": {
+ "t3": [
+ {"a": 1},
+ ],
+ "t4": [
+ {"a": 1},
+ ],
+ },
+ },
+ "catalog2": {
+ "db3": {
+ "t5": Table(columns=("a",), rows=[(1,)]),
+ "t6": Table(columns=("a",), rows=[(1,)]),
+ },
+ "db4": {
+ "t7": Table(columns=("a",), rows=[(1,)]),
+ "t8": Table(columns=("a",), rows=[(1,)]),
+ },
+ },
+ }
+ )
+
+ t1 = tables.find(exp.table_(table="t1", db="db1", catalog="catalog1"))
+ self.assertEqual(t1.columns, ("a",))
+ self.assertEqual(t1.rows, [(1,)])
+
+ t8 = tables.find(exp.table_(table="t8"))
+ self.assertEqual(t1.columns, t8.columns)
+ self.assertEqual(t1.rows, t8.rows)
+
+ def test_static_queries(self):
+ for sql, cols, rows in [
+ ("SELECT 1", ["_col_0"], [(1,)]),
+ ("SELECT 1 + 2 AS x", ["x"], [(3,)]),
+ ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
+ ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
+ ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
+ ]:
+ result = execute(sql)
+ self.assertEqual(result.columns, tuple(cols))
+ self.assertEqual(result.rows, rows)
+
+ def test_aggregate_without_group_by(self):
+ result = execute("SELECT SUM(x) FROM t", tables={"t": [{"x": 1}, {"x": 2}]})
+ self.assertEqual(result.columns, ("_col_0",))
+ self.assertEqual(result.rows, [(3,)])
+
+ def test_scalar_functions(self):
+ for sql, expected in [
+ ("CONCAT('a', 'b')", "ab"),
+ ("CONCAT('a', NULL)", None),
+ ("CONCAT_WS('_', 'a', 'b')", "a_b"),
+ ("STR_POSITION('bar', 'foobarbar')", 4),
+ ("STR_POSITION('bar', 'foobarbar', 5)", 7),
+ ("STR_POSITION(NULL, 'foobarbar')", None),
+ ("STR_POSITION('bar', NULL)", None),
+ ("UPPER('foo')", "FOO"),
+ ("UPPER(NULL)", None),
+ ("LOWER('FOO')", "foo"),
+ ("LOWER(NULL)", None),
+ ("IFNULL('a', 'b')", "a"),
+ ("IFNULL(NULL, 'b')", "b"),
+ ("IFNULL(NULL, NULL)", None),
+ ("SUBSTRING('12345')", "12345"),
+ ("SUBSTRING('12345', 3)", "345"),
+ ("SUBSTRING('12345', 3, 0)", ""),
+ ("SUBSTRING('12345', 3, 1)", "3"),
+ ("SUBSTRING('12345', 3, 2)", "34"),
+ ("SUBSTRING('12345', 3, 3)", "345"),
+ ("SUBSTRING('12345', 3, 4)", "345"),
+ ("SUBSTRING('12345', -3)", "345"),
+ ("SUBSTRING('12345', -3, 0)", ""),
+ ("SUBSTRING('12345', -3, 1)", "3"),
+ ("SUBSTRING('12345', -3, 2)", "34"),
+ ("SUBSTRING('12345', 0)", ""),
+ ("SUBSTRING('12345', 0, 1)", ""),
+ ("SUBSTRING(NULL)", None),
+ ("SUBSTRING(NULL, 1)", None),
+ ("CAST(1 AS TEXT)", "1"),
+ ("CAST('1' AS LONG)", 1),
+ ("CAST('1.1' AS FLOAT)", 1.1),
+ ("COALESCE(NULL)", None),
+ ("COALESCE(NULL, NULL)", None),
+ ("COALESCE(NULL, 'b')", "b"),
+ ("COALESCE('a', 'b')", "a"),
+ ("1 << 1", 2),
+ ("1 >> 1", 0),
+ ("1 & 1", 1),
+ ("1 | 1", 1),
+ ("1 < 1", False),
+ ("1 <= 1", True),
+ ("1 > 1", False),
+ ("1 >= 1", True),
+ ("1 + NULL", None),
+ ("IF(true, 1, 0)", 1),
+ ("IF(false, 1, 0)", 0),
+ ("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"),
+ ("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)),
+ ]:
+ with self.subTest(sql):
+ result = execute(f"SELECT {sql}")
+ self.assertEqual(result.rows, [(expected,)])