diff options
Diffstat (limited to '')
-rw-r--r-- | tests/dialects/test_bigquery.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 12 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 9 | ||||
-rw-r--r-- | tests/fixtures/optimizer/merge_subqueries.sql | 8 | ||||
-rw-r--r-- | tests/fixtures/optimizer/optimizer.sql | 17 | ||||
-rw-r--r-- | tests/fixtures/optimizer/pushdown_projections.sql | 12 | ||||
-rw-r--r-- | tests/test_optimizer.py | 19 | ||||
-rw-r--r-- | tests/test_parser.py | 3 |
9 files changed, 79 insertions, 7 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index c929e59..7110eac 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -239,7 +239,7 @@ class TestBigQuery(Validator): self.validate_all( "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", write={ - "spark": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", + "spark": "SELECT cola, colb FROM VALUES (1, 'test') AS tab(cola, colb)", "bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])", "snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", }, @@ -253,7 +253,7 @@ class TestBigQuery(Validator): def test_user_defined_functions(self): self.validate_identity( - "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 LANGUAGE js AS 'return x*y;'" + "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'" ) self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)") self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index e0ec824..a9a313c 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1009,7 +1009,7 @@ class TestDialect(Validator): self.validate_all( "SELECT * FROM VALUES ('x'), ('y') AS t(z)", write={ - "spark": "SELECT * FROM (VALUES ('x'), ('y')) AS t(z)", + "spark": "SELECT * FROM VALUES ('x'), ('y') AS t(z)", }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index b7e39a7..2145966 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -293,3 +293,15 @@ class TestSnowflake(Validator): "bigquery": "CREATE TABLE FUNCTION a() RETURNS TABLE <b INT64> AS SELECT 1", }, ) + self.validate_all( + "CREATE FUNCTION a() RETURNS INT IMMUTABLE AS 'SELECT 1'", + write={ + "snowflake": "CREATE FUNCTION a() RETURNS INT IMMUTABLE AS 'SELECT 1'", + }, + ) + + def test_stored_procedures(self): + self.validate_identity("CALL a.b.c(x, y)") + self.validate_identity( + "CREATE PROCEDURE a.b.c(x INT, y VARIANT) RETURNS OBJECT EXECUTE AS CALLER AS 'BEGIN SELECT 1; END;'" + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 2654be1..a0de281 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -50,6 +50,7 @@ a.B() a['x'].C() int.x map.x +a.b.INT(1.234) x IN (-1, 1) x IN ('a', 'a''a') x IN ((1)) @@ -357,6 +358,7 @@ SELECT * REPLACE (a + 1 AS b, b AS C) SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C) SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C) SELECT zoo, animals FROM (VALUES ('oakland', ARRAY('a', 'b')), ('sf', ARRAY('b', 'c'))) AS t(zoo, animals) +SELECT zoo, animals FROM UNNEST(ARRAY(STRUCT('oakland' AS zoo, ARRAY('a', 'b') AS animals), STRUCT('sf' AS zoo, ARRAY('b', 'c') AS animals))) AS t(zoo, animals) WITH a AS (SELECT 1) SELECT 1 UNION ALL SELECT 2 WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2 WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2 @@ -444,6 +446,8 @@ CREATE OR REPLACE TEMPORARY VIEW x AS SELECT * CREATE TEMPORARY VIEW x AS SELECT a FROM d CREATE TEMPORARY VIEW IF NOT EXISTS x AS SELECT a FROM d CREATE TEMPORARY VIEW x AS WITH y AS (SELECT 1) SELECT * FROM y +CREATE MATERIALIZED VIEW x.y.z AS SELECT a FROM b +DROP MATERIALIZED VIEW x.y.z CREATE TABLE z (a INT, b VARCHAR, c VARCHAR(100), d DECIMAL(5, 3)) CREATE TABLE z (end INT) CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3)) @@ -471,10 +475,13 @@ CREATE FUNCTION f AS 'g' CREATE FUNCTION a(b INT, c VARCHAR) AS 'SELECT 1' CREATE FUNCTION a() LANGUAGE sql CREATE FUNCTION a() LANGUAGE sql RETURNS INT +CREATE FUNCTION a.b.c() +DROP FUNCTION a.b.c (INT) CREATE INDEX abc ON t (a) CREATE INDEX abc ON t (a, b, b) CREATE UNIQUE INDEX abc ON t (a, b, b) CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl (a, b) +DROP INDEX a.b.c CACHE TABLE x CACHE LAZY TABLE x CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') @@ -484,6 +491,8 @@ CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a CACHE TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a CACHE TABLE x AS (SELECT 1 AS y) CALL catalog.system.iceberg_procedure_name(named_arg_1 => 'arg_1', named_arg_2 => 'arg_2') +CREATE PROCEDURE IF NOT EXISTS a.b.c() AS 'DECLARE BEGIN; END' +DROP PROCEDURE a.b.c (INT) INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y INSERT OVERWRITE TABLE a.b PARTITION(ds, hour) SELECT x FROM y diff --git a/tests/fixtures/optimizer/merge_subqueries.sql b/tests/fixtures/optimizer/merge_subqueries.sql index 35aed3b..e13d3b3 100644 --- a/tests/fixtures/optimizer/merge_subqueries.sql +++ b/tests/fixtures/optimizer/merge_subqueries.sql @@ -97,3 +97,11 @@ WITH x AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT x.a AS a, y.b AS b FROM -- Nested CTE SELECT * FROM (WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x); SELECT x.a AS a, x.b AS b FROM x AS x; + +-- Inner select is an expression +SELECT a FROM (SELECT a FROM (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) AS x) AS x; +SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b; + +-- CTE select is an expression +WITH x AS (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) SELECT a FROM (SELECT a FROM x AS x) AS x; +SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index 0bb742b..eb6761a 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -137,3 +137,20 @@ SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x; SELECT AGGREGATE(ARRAY("x"."a", "x"."b"), 0, ("x", "acc") -> "x" + "acc" + "x"."a") AS "sum_agg" FROM "x" AS "x"; + +SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb); +SELECT + "tab"."cola" AS "cola", + "tab"."colb" AS "colb" +FROM (VALUES + (1, 'test'), + (2, 'test2')) AS "tab"("cola", "colb"); + +# dialect: spark +SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb); +SELECT + `tab`.`cola` AS `cola`, + `tab`.`colb` AS `colb` +FROM VALUES + (1, 'test'), + (2, 'test2') AS `tab`(`cola`, `colb`); diff --git a/tests/fixtures/optimizer/pushdown_projections.sql b/tests/fixtures/optimizer/pushdown_projections.sql index 9deceb6..b03ffab 100644 --- a/tests/fixtures/optimizer/pushdown_projections.sql +++ b/tests/fixtures/optimizer/pushdown_projections.sql @@ -39,3 +39,15 @@ SELECT "_q_0".b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS "_q SELECT b FROM (SELECT a, SUM(b) AS b FROM x ORDER BY a); SELECT "_q_0".b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS "_q_0"; + +SELECT x FROM (VALUES(1, 2)) AS q(x, y); +SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y); + +SELECT x FROM UNNEST([1, 2]) AS q(x, y); +SELECT q.x AS x FROM UNNEST(ARRAY(1, 2)) AS q(x, y); + +WITH t1 AS (SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)]) AS "q"("cola", "colb")) SELECT cola FROM t1; +WITH t1 AS (SELECT q.cola AS cola FROM UNNEST(ARRAY(STRUCT(1 AS cola, 'test' AS colb))) AS "q"("cola", "colb")) SELECT t1.cola AS cola FROM t1; + +SELECT x FROM VALUES(1, 2) AS q(x, y); +SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y); diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 8d4aecc..aad84ed 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -5,7 +5,7 @@ from sqlglot import exp, optimizer, parse_one, table from sqlglot.errors import OptimizeError from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.schema import MappingSchema, ensure_schema -from sqlglot.optimizer.scope import build_scope, traverse_scope +from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures @@ -264,12 +264,13 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') ON s.b = r.b WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b) """ - for scopes in traverse_scope(parse_one(sql)), list(build_scope(parse_one(sql)).traverse()): + expression = parse_one(sql) + for scopes in traverse_scope(expression), list(build_scope(expression).traverse()): self.assertEqual(len(scopes), 5) self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") - self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") - self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y") + self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y") + self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql()) self.assertEqual(set(scopes[4].sources), {"q", "r", "s"}) @@ -279,6 +280,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(len(scopes[4].source_columns("r")), 2) self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"}) + self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"}) + self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b") + self.assertEqual({c.sql() for c in scopes[0].find_all(exp.Column)}, {"x.b"}) + + # Check that we can walk in scope from an arbitrary node + self.assertEqual( + {node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)}, + {"s.b"}, + ) + def test_literal_type_annotation(self): tests = { "SELECT 5": exp.DataType.Type.INT, diff --git a/tests/test_parser.py b/tests/test_parser.py index 4c46531..4e86516 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -122,6 +122,9 @@ class TestParser(unittest.TestCase): def test_parameter(self): self.assertEqual(parse_one("SELECT @x, @@x, @1").sql(), "SELECT @x, @@x, @1") + def test_var(self): + self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'") + def test_annotations(self): expression = parse_one( """ |