summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/dialects/test_bigquery.py4
-rw-r--r--tests/dialects/test_dialect.py2
-rw-r--r--tests/dialects/test_snowflake.py12
-rw-r--r--tests/fixtures/identity.sql9
-rw-r--r--tests/fixtures/optimizer/merge_subqueries.sql8
-rw-r--r--tests/fixtures/optimizer/optimizer.sql17
-rw-r--r--tests/fixtures/optimizer/pushdown_projections.sql12
-rw-r--r--tests/test_optimizer.py19
-rw-r--r--tests/test_parser.py3
9 files changed, 79 insertions, 7 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index c929e59..7110eac 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -239,7 +239,7 @@ class TestBigQuery(Validator):
self.validate_all(
"SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
write={
- "spark": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
+ "spark": "SELECT cola, colb FROM VALUES (1, 'test') AS tab(cola, colb)",
"bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])",
"snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
},
@@ -253,7 +253,7 @@ class TestBigQuery(Validator):
def test_user_defined_functions(self):
self.validate_identity(
- "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 LANGUAGE js AS 'return x*y;'"
+ "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'"
)
self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)")
self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t")
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index e0ec824..a9a313c 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -1009,7 +1009,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT * FROM VALUES ('x'), ('y') AS t(z)",
write={
- "spark": "SELECT * FROM (VALUES ('x'), ('y')) AS t(z)",
+ "spark": "SELECT * FROM VALUES ('x'), ('y') AS t(z)",
},
)
self.validate_all(
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index b7e39a7..2145966 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -293,3 +293,15 @@ class TestSnowflake(Validator):
"bigquery": "CREATE TABLE FUNCTION a() RETURNS TABLE <b INT64> AS SELECT 1",
},
)
+ self.validate_all(
+ "CREATE FUNCTION a() RETURNS INT IMMUTABLE AS 'SELECT 1'",
+ write={
+ "snowflake": "CREATE FUNCTION a() RETURNS INT IMMUTABLE AS 'SELECT 1'",
+ },
+ )
+
+ def test_stored_procedures(self):
+ self.validate_identity("CALL a.b.c(x, y)")
+ self.validate_identity(
+ "CREATE PROCEDURE a.b.c(x INT, y VARIANT) RETURNS OBJECT EXECUTE AS CALLER AS 'BEGIN SELECT 1; END;'"
+ )
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index 2654be1..a0de281 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -50,6 +50,7 @@ a.B()
a['x'].C()
int.x
map.x
+a.b.INT(1.234)
x IN (-1, 1)
x IN ('a', 'a''a')
x IN ((1))
@@ -357,6 +358,7 @@ SELECT * REPLACE (a + 1 AS b, b AS C)
SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C)
SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C)
SELECT zoo, animals FROM (VALUES ('oakland', ARRAY('a', 'b')), ('sf', ARRAY('b', 'c'))) AS t(zoo, animals)
+SELECT zoo, animals FROM UNNEST(ARRAY(STRUCT('oakland' AS zoo, ARRAY('a', 'b') AS animals), STRUCT('sf' AS zoo, ARRAY('b', 'c') AS animals))) AS t(zoo, animals)
WITH a AS (SELECT 1) SELECT 1 UNION ALL SELECT 2
WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2
WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2
@@ -444,6 +446,8 @@ CREATE OR REPLACE TEMPORARY VIEW x AS SELECT *
CREATE TEMPORARY VIEW x AS SELECT a FROM d
CREATE TEMPORARY VIEW IF NOT EXISTS x AS SELECT a FROM d
CREATE TEMPORARY VIEW x AS WITH y AS (SELECT 1) SELECT * FROM y
+CREATE MATERIALIZED VIEW x.y.z AS SELECT a FROM b
+DROP MATERIALIZED VIEW x.y.z
CREATE TABLE z (a INT, b VARCHAR, c VARCHAR(100), d DECIMAL(5, 3))
CREATE TABLE z (end INT)
CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3))
@@ -471,10 +475,13 @@ CREATE FUNCTION f AS 'g'
CREATE FUNCTION a(b INT, c VARCHAR) AS 'SELECT 1'
CREATE FUNCTION a() LANGUAGE sql
CREATE FUNCTION a() LANGUAGE sql RETURNS INT
+CREATE FUNCTION a.b.c()
+DROP FUNCTION a.b.c (INT)
CREATE INDEX abc ON t (a)
CREATE INDEX abc ON t (a, b, b)
CREATE UNIQUE INDEX abc ON t (a, b, b)
CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl (a, b)
+DROP INDEX a.b.c
CACHE TABLE x
CACHE LAZY TABLE x
CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value')
@@ -484,6 +491,8 @@ CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
CACHE TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
CACHE TABLE x AS (SELECT 1 AS y)
CALL catalog.system.iceberg_procedure_name(named_arg_1 => 'arg_1', named_arg_2 => 'arg_2')
+CREATE PROCEDURE IF NOT EXISTS a.b.c() AS 'DECLARE BEGIN; END'
+DROP PROCEDURE a.b.c (INT)
INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds, hour) SELECT x FROM y
diff --git a/tests/fixtures/optimizer/merge_subqueries.sql b/tests/fixtures/optimizer/merge_subqueries.sql
index 35aed3b..e13d3b3 100644
--- a/tests/fixtures/optimizer/merge_subqueries.sql
+++ b/tests/fixtures/optimizer/merge_subqueries.sql
@@ -97,3 +97,11 @@ WITH x AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT x.a AS a, y.b AS b FROM
-- Nested CTE
SELECT * FROM (WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x);
SELECT x.a AS a, x.b AS b FROM x AS x;
+
+-- Inner select is an expression
+SELECT a FROM (SELECT a FROM (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) AS x) AS x;
+SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b;
+
+-- CTE select is an expression
+WITH x AS (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) SELECT a FROM (SELECT a FROM x AS x) AS x;
+SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b;
diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql
index 0bb742b..eb6761a 100644
--- a/tests/fixtures/optimizer/optimizer.sql
+++ b/tests/fixtures/optimizer/optimizer.sql
@@ -137,3 +137,20 @@ SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT
AGGREGATE(ARRAY("x"."a", "x"."b"), 0, ("x", "acc") -> "x" + "acc" + "x"."a") AS "sum_agg"
FROM "x" AS "x";
+
+SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb);
+SELECT
+ "tab"."cola" AS "cola",
+ "tab"."colb" AS "colb"
+FROM (VALUES
+ (1, 'test'),
+ (2, 'test2')) AS "tab"("cola", "colb");
+
+# dialect: spark
+SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb);
+SELECT
+ `tab`.`cola` AS `cola`,
+ `tab`.`colb` AS `colb`
+FROM VALUES
+ (1, 'test'),
+ (2, 'test2') AS `tab`(`cola`, `colb`);
diff --git a/tests/fixtures/optimizer/pushdown_projections.sql b/tests/fixtures/optimizer/pushdown_projections.sql
index 9deceb6..b03ffab 100644
--- a/tests/fixtures/optimizer/pushdown_projections.sql
+++ b/tests/fixtures/optimizer/pushdown_projections.sql
@@ -39,3 +39,15 @@ SELECT "_q_0".b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS "_q
SELECT b FROM (SELECT a, SUM(b) AS b FROM x ORDER BY a);
SELECT "_q_0".b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS "_q_0";
+
+SELECT x FROM (VALUES(1, 2)) AS q(x, y);
+SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);
+
+SELECT x FROM UNNEST([1, 2]) AS q(x, y);
+SELECT q.x AS x FROM UNNEST(ARRAY(1, 2)) AS q(x, y);
+
+WITH t1 AS (SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)]) AS "q"("cola", "colb")) SELECT cola FROM t1;
+WITH t1 AS (SELECT q.cola AS cola FROM UNNEST(ARRAY(STRUCT(1 AS cola, 'test' AS colb))) AS "q"("cola", "colb")) SELECT t1.cola AS cola FROM t1;
+
+SELECT x FROM VALUES(1, 2) AS q(x, y);
+SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y);
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 8d4aecc..aad84ed 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -5,7 +5,7 @@ from sqlglot import exp, optimizer, parse_one, table
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
-from sqlglot.optimizer.scope import build_scope, traverse_scope
+from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures
@@ -264,12 +264,13 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
ON s.b = r.b
WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b)
"""
- for scopes in traverse_scope(parse_one(sql)), list(build_scope(parse_one(sql)).traverse()):
+ expression = parse_one(sql)
+ for scopes in traverse_scope(expression), list(build_scope(expression).traverse()):
self.assertEqual(len(scopes), 5)
self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
- self.assertEqual(scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
- self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y")
+ self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y")
+ self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
self.assertEqual(set(scopes[4].sources), {"q", "r", "s"})
@@ -279,6 +280,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(len(scopes[4].source_columns("r")), 2)
self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})
+ self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"})
+ self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b")
+ self.assertEqual({c.sql() for c in scopes[0].find_all(exp.Column)}, {"x.b"})
+
+ # Check that we can walk in scope from an arbitrary node
+ self.assertEqual(
+ {node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)},
+ {"s.b"},
+ )
+
def test_literal_type_annotation(self):
tests = {
"SELECT 5": exp.DataType.Type.INT,
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 4c46531..4e86516 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -122,6 +122,9 @@ class TestParser(unittest.TestCase):
def test_parameter(self):
self.assertEqual(parse_one("SELECT @x, @@x, @1").sql(), "SELECT @x, @@x, @1")
+ def test_var(self):
+ self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'")
+
def test_annotations(self):
expression = parse_one(
"""