diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-07-06 07:28:09 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-07-06 07:28:09 +0000 |
commit | 52f4a5e2260f3e5b919b4e270339afd670bf0b8a (patch) | |
tree | 5ca419af0e2e409018492b82f5b9847f0112b5fb /tests | |
parent | Adding upstream version 16.7.7. (diff) | |
download | sqlglot-52f4a5e2260f3e5b919b4e270339afd670bf0b8a.tar.xz sqlglot-52f4a5e2260f3e5b919b4e270339afd670bf0b8a.zip |
Adding upstream version 17.2.0.upstream/17.2.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r-- | tests/dialects/test_bigquery.py | 34 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 17 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 45 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 7 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 5 | ||||
-rw-r--r-- | tests/dialects/test_sqlite.py | 9 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 31 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 30 | ||||
-rw-r--r-- | tests/fixtures/pretty.sql | 20 | ||||
-rw-r--r-- | tests/test_executor.py | 23 | ||||
-rw-r--r-- | tests/test_expressions.py | 10 | ||||
-rw-r--r-- | tests/test_optimizer.py | 27 | ||||
-rw-r--r-- | tests/test_parser.py | 53 | ||||
-rw-r--r-- | tests/test_transpile.py | 6 |
17 files changed, 274 insertions, 51 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index e3fc495..eac3cac 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -29,12 +29,14 @@ class TestBigQuery(Validator): with self.assertRaises(ParseError): transpile("SELECT * FROM UNNEST(x) AS x(y)", read="bigquery") + self.validate_identity("SELECT PARSE_TIMESTAMP('%c', 'Thu Dec 25 07:30:00 2008', 'UTC')") self.validate_identity("SELECT ANY_VALUE(fruit HAVING MAX sold) FROM fruits") self.validate_identity("SELECT ANY_VALUE(fruit HAVING MIN sold) FROM fruits") self.validate_identity("SELECT `project-id`.udfs.func(call.dir)") self.validate_identity("SELECT CAST(CURRENT_DATE AS STRING FORMAT 'DAY') AS current_day") self.validate_identity("SAFE_CAST(encrypted_value AS STRING FORMAT 'BASE64')") self.validate_identity("CAST(encrypted_value AS STRING FORMAT 'BASE64')") + self.validate_identity("CAST(STRUCT<a INT64>(1) AS STRUCT<a INT64>)") self.validate_identity("STRING_AGG(a)") self.validate_identity("STRING_AGG(a, ' & ')") self.validate_identity("STRING_AGG(DISTINCT a, ' & ')") @@ -106,6 +108,14 @@ class TestBigQuery(Validator): self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"bigquery": "CAST(x AS TIMESTAMP)"}) self.validate_all("CAST(x AS RECORD)", write={"bigquery": "CAST(x AS STRUCT)"}) self.validate_all( + "SELECT CAST('20201225' AS TIMESTAMP FORMAT 'YYYYMMDD' AT TIME ZONE 'America/New_York')", + write={"bigquery": "SELECT PARSE_TIMESTAMP('%Y%m%d', '20201225', 'America/New_York')"}, + ) + self.validate_all( + "SELECT CAST('20201225' AS TIMESTAMP FORMAT 'YYYYMMDD')", + write={"bigquery": "SELECT PARSE_TIMESTAMP('%Y%m%d', '20201225')"}, + ) + self.validate_all( "SELECT CAST(TIMESTAMP '2008-12-25 00:00:00+00:00' AS STRING FORMAT 'YYYY-MM-DD HH24:MI:SS TZH:TZM') AS date_time_to_string", write={ "bigquery": "SELECT CAST(CAST('2008-12-25 00:00:00+00:00' AS TIMESTAMP) AS STRING FORMAT 'YYYY-MM-DD HH24:MI:SS TZH:TZM') AS date_time_to_string", @@ -191,7 +201,7 @@ class TestBigQuery(Validator): self.validate_all( "r'x\\''", write={ - "bigquery": "r'x\\''", + "bigquery": "'x\\''", "hive": "'x\\''", }, ) @@ -199,7 +209,7 @@ class TestBigQuery(Validator): self.validate_all( "r'x\\y'", write={ - "bigquery": "r'x\\y'", + "bigquery": "'x\\\y'", "hive": "'x\\\\y'", }, ) @@ -215,7 +225,7 @@ class TestBigQuery(Validator): self.validate_all( r'r"""/\*.*\*/"""', write={ - "bigquery": r"r'/\*.*\*/'", + "bigquery": r"'/\\*.*\\*/'", "duckdb": r"'/\\*.*\\*/'", "presto": r"'/\\*.*\\*/'", "hive": r"'/\\*.*\\*/'", @@ -225,7 +235,7 @@ class TestBigQuery(Validator): self.validate_all( r'R"""/\*.*\*/"""', write={ - "bigquery": r"r'/\*.*\*/'", + "bigquery": r"'/\\*.*\\*/'", "duckdb": r"'/\\*.*\\*/'", "presto": r"'/\\*.*\\*/'", "hive": r"'/\\*.*\\*/'", @@ -233,6 +243,20 @@ class TestBigQuery(Validator): }, ) self.validate_all( + 'r"""a\n"""', + write={ + "bigquery": "'a\\n'", + "duckdb": "'a\n'", + }, + ) + self.validate_all( + '"""a\n"""', + write={ + "bigquery": "'a\\n'", + "duckdb": "'a\n'", + }, + ) + self.validate_all( "CAST(a AS INT64)", write={ "bigquery": "CAST(a AS INT64)", @@ -603,7 +627,7 @@ class TestBigQuery(Validator): ) @mock.patch("sqlglot.dialects.bigquery.logger") - def test_pushdown_cte_column_names(self, mock_logger): + def test_pushdown_cte_column_names(self, logger): with self.assertRaises(UnsupportedError): transpile( "WITH cte(foo) AS (SELECT * FROM tbl) SELECT foo FROM cte", diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 78f87ff..21efc6b 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -601,7 +601,6 @@ class TestDialect(Validator): "DATE_TRUNC('day', x)", read={ "bigquery": "DATE_TRUNC(x, day)", - "duckdb": "DATE_TRUNC('day', x)", "spark": "TRUNC(x, 'day')", }, write={ @@ -619,6 +618,7 @@ class TestDialect(Validator): "TIMESTAMP_TRUNC(x, day)", read={ "bigquery": "TIMESTAMP_TRUNC(x, day)", + "duckdb": "DATE_TRUNC('day', x)", "presto": "DATE_TRUNC('day', x)", "postgres": "DATE_TRUNC('day', x)", "snowflake": "DATE_TRUNC('day', x)", @@ -1307,7 +1307,7 @@ class TestDialect(Validator): write={ "sqlite": "SELECT x FROM y LIMIT 10", "oracle": "SELECT x FROM y FETCH FIRST 10 ROWS ONLY", - "tsql": "SELECT x FROM y FETCH FIRST 10 ROWS ONLY", + "tsql": "SELECT TOP 10 x FROM y", }, ) self.validate_all( diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 4065f81..cad1c15 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -10,6 +10,23 @@ class TestDuckDB(Validator): self.validate_identity("SELECT CURRENT_TIMESTAMP") self.validate_all( + "SELECT MAKE_DATE(2016, 12, 25)", read={"bigquery": "SELECT DATE(2016, 12, 25)"} + ) + self.validate_all( + "SELECT CAST(CAST('2016-12-25 23:59:59' AS DATETIME) AS DATE)", + read={"bigquery": "SELECT DATE(DATETIME '2016-12-25 23:59:59')"}, + ) + self.validate_all( + "SELECT STRPTIME(STRFTIME(CAST(CAST('2016-12-25' AS TIMESTAMPTZ) AS DATE), '%d/%m/%Y') || ' ' || 'America/Los_Angeles', '%d/%m/%Y %Z')", + read={ + "bigquery": "SELECT DATE(TIMESTAMP '2016-12-25', 'America/Los_Angeles')", + }, + ) + self.validate_all( + "SELECT CAST(CAST(STRPTIME('05/06/2020', '%m/%d/%Y') AS DATE) AS DATE)", + read={"bigquery": "SELECT DATE(PARSE_DATE('%m/%d/%Y', '05/06/2020'))"}, + ) + self.validate_all( "SELECT CAST('2020-01-01' AS DATE) + INTERVAL (-1) DAY", read={"mysql": "SELECT DATE '2020-01-01' + INTERVAL -1 DAY"}, ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index ca2f921..3539ad0 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -10,7 +10,28 @@ class TestMySQL(Validator): self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10") self.validate_identity("DELETE FROM t WHERE a <= 10 LIMIT 10") self.validate_identity( - "INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE SET x.id = 1" + "DELETE t1 FROM t1 LEFT JOIN t2 ON t1.id = t2.id WHERE t2.id IS NULL" + ) + self.validate_identity( + "DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id = t2.id AND t2.id = t3.id" + ) + self.validate_identity( + "DELETE FROM t1, t2 USING t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id = t2.id AND t2.id = t3.id" + ) + self.validate_identity( + "INSERT IGNORE INTO subscribers (email) VALUES ('john.doe@gmail.com'), ('jane.smith@ibm.com')" + ) + self.validate_identity( + "INSERT INTO t1 (a, b, c) VALUES (1, 2, 3), (4, 5, 6) ON DUPLICATE KEY UPDATE c = VALUES(a) + VALUES(b)" + ) + self.validate_identity( + "INSERT INTO t1 (a, b) SELECT c, d FROM t2 UNION SELECT e, f FROM t3 ON DUPLICATE KEY UPDATE b = b + c" + ) + self.validate_identity( + "INSERT INTO t1 (a, b, c) VALUES (1, 2, 3) ON DUPLICATE KEY UPDATE c = c + 1" + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE x.id = 1" ) self.validate_all( @@ -48,6 +69,14 @@ class TestMySQL(Validator): ) def test_identity(self): + self.validate_identity("SELECT /*+ BKA(t1) NO_BKA(t2) */ * FROM t1 INNER JOIN t2") + self.validate_identity("SELECT /*+ MERGE(dt) */ * FROM (SELECT * FROM t1) AS dt") + self.validate_identity("SELECT /*+ INDEX(t, i) */ c1 FROM t WHERE c2 = 'value'") + self.validate_identity("SELECT @a MEMBER OF(@c), @b MEMBER OF(@c)") + self.validate_identity("SELECT JSON_ARRAY(4, 5) MEMBER OF('[[3,4],[4,5]]')") + self.validate_identity("SELECT CAST('[4,5]' AS JSON) MEMBER OF('[[3,4],[4,5]]')") + self.validate_identity("""SELECT 'ab' MEMBER OF('[23, "abc", 17, "ab", 10]')""") + self.validate_identity("""SELECT * FROM foo WHERE 'ab' MEMBER OF(content)""") self.validate_identity("CAST(x AS ENUM('a', 'b'))") self.validate_identity("CAST(x AS SET('a', 'b'))") self.validate_identity("SELECT CURRENT_TIMESTAMP(6)") @@ -61,8 +90,15 @@ class TestMySQL(Validator): self.validate_identity("CREATE TABLE A LIKE B") self.validate_identity("SELECT * FROM t1, t2 FOR SHARE OF t1, t2 SKIP LOCKED") self.validate_identity( + """SELECT * FROM foo WHERE 3 MEMBER OF(JSON_EXTRACT(info, '$.value'))""" + ) + self.validate_identity( "SELECT * FROM t1, t2, t3 FOR SHARE OF t1 NOWAIT FOR UPDATE OF t2, t3 SKIP LOCKED" ) + self.validate_identity( + """SELECT * FROM foo WHERE 3 MEMBER OF(info->'$.value')""", + """SELECT * FROM foo WHERE 3 MEMBER OF(JSON_EXTRACT(info, '$.value'))""", + ) # Index hints self.validate_identity( @@ -403,6 +439,13 @@ class TestMySQL(Validator): self.validate_all("CAST(x AS UNSIGNED)", write={"mysql": "CAST(x AS UNSIGNED)"}) self.validate_all("CAST(x AS UNSIGNED INTEGER)", write={"mysql": "CAST(x AS UNSIGNED)"}) self.validate_all( + """SELECT 17 MEMBER OF('[23, "abc", 17, "ab", 10]')""", + write={ + "": """SELECT JSON_ARRAY_CONTAINS(17, '[23, "abc", 17, "ab", 10]')""", + "mysql": """SELECT 17 MEMBER OF('[23, "abc", 17, "ab", 10]')""", + }, + ) + self.validate_all( "SELECT DATE_ADD('2023-06-23 12:00:00', INTERVAL 2 * 2 MONTH) FROM foo", write={ "mysql": "SELECT DATE_ADD('2023-06-23 12:00:00', INTERVAL (2 * 2) MONTH) FROM foo", diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index c391052..052d4cc 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -110,7 +110,7 @@ class TestPostgres(Validator): ) @mock.patch("sqlglot.helper.logger") - def test_array_offset(self, mock_logger): + def test_array_offset(self, logger): self.validate_all( "SELECT col[1]", write={ diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 49139f9..45a0cd9 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -442,7 +442,7 @@ class TestPresto(Validator): ) @mock.patch("sqlglot.helper.logger") - def test_presto(self, mock_logger): + def test_presto(self, logger): self.validate_identity("SELECT * FROM x OFFSET 1 LIMIT 1") self.validate_identity("SELECT * FROM x OFFSET 1 FETCH FIRST 1 ROWS ONLY") self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)") diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index e9826a6..f7bab4d 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -6,6 +6,7 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): + self.validate_identity("SELECT SUM(amount) FROM mytable GROUP BY ALL") self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT foo FROM IDENTIFIER('x')") self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT IDENTIFIER('foo') FROM x") self.validate_identity("INITCAP('iqamqinterestedqinqthisqtopic', 'q')") @@ -33,6 +34,9 @@ class TestSnowflake(Validator): self.validate_identity( 'COPY INTO NEW_TABLE ("foo", "bar") FROM (SELECT $1, $2, $3, $4 FROM @%old_table)' ) + self.validate_identity( + "SELECT state, city, SUM(retail_price * quantity) AS gross_revenue FROM sales GROUP BY ALL" + ) self.validate_all("CAST(x AS BYTEINT)", write={"snowflake": "CAST(x AS INT)"}) self.validate_all("CAST(x AS CHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) @@ -611,6 +615,9 @@ class TestSnowflake(Validator): self.validate_identity( "CREATE SCHEMA mytestschema_clone_restore CLONE testschema BEFORE (TIMESTAMP => TO_TIMESTAMP(40 * 365 * 86400))" ) + self.validate_identity( + "CREATE OR REPLACE TABLE EXAMPLE_DB.DEMO.USERS (ID DECIMAL(38, 0) NOT NULL, PRIMARY KEY (ID), FOREIGN KEY (CITY_CODE) REFERENCES EXAMPLE_DB.DEMO.CITIES (CITY_CODE))" + ) self.validate_all( "CREATE OR REPLACE TRANSIENT TABLE a (id INT)", diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 8acc48e..25841c5 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -1,3 +1,5 @@ +from unittest import mock + from tests.dialects.test_dialect import Validator @@ -148,7 +150,8 @@ TBLPROPERTIES ( }, ) - def test_hint(self): + @mock.patch("sqlglot.generator.logger") + def test_hint(self, logger): self.validate_all( "SELECT /*+ COALESCE(3) */ * FROM x", write={ diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index 583d5be..10da9b0 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -20,8 +20,9 @@ class TestSQLite(Validator): CREATE TABLE "Track" ( CONSTRAINT "PK_Track" FOREIGN KEY ("TrackId"), - FOREIGN KEY ("AlbumId") REFERENCES "Album" ("AlbumId") - ON DELETE NO ACTION ON UPDATE NO ACTION, + FOREIGN KEY ("AlbumId") REFERENCES "Album" ( + "AlbumId" + ) ON DELETE NO ACTION ON UPDATE NO ACTION, FOREIGN KEY ("AlbumId") ON DELETE CASCADE ON UPDATE RESTRICT, FOREIGN KEY ("AlbumId") ON DELETE SET NULL ON UPDATE SET DEFAULT ) @@ -29,7 +30,9 @@ class TestSQLite(Validator): write={ "sqlite": """CREATE TABLE "Track" ( CONSTRAINT "PK_Track" FOREIGN KEY ("TrackId"), - FOREIGN KEY ("AlbumId") REFERENCES "Album"("AlbumId") ON DELETE NO ACTION ON UPDATE NO ACTION, + FOREIGN KEY ("AlbumId") REFERENCES "Album" ( + "AlbumId" + ) ON DELETE NO ACTION ON UPDATE NO ACTION, FOREIGN KEY ("AlbumId") ON DELETE CASCADE ON UPDATE RESTRICT, FOREIGN KEY ("AlbumId") ON DELETE SET NULL ON UPDATE SET DEFAULT )""", diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 953d64d..ca6d70c 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -17,17 +17,29 @@ class TestTSQL(Validator): self.validate_identity("PRINT @TestVariable") self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar") self.validate_identity("INSERT INTO @TestTable VALUES (1, 'Value1', 12, 20)") - self.validate_identity( - "SELECT x FROM @MyTableVar AS m JOIN Employee ON m.EmployeeID = Employee.EmployeeID" - ) self.validate_identity('SELECT "x"."y" FROM foo') self.validate_identity("SELECT * FROM #foo") self.validate_identity("SELECT * FROM ##foo") self.validate_identity( + "SELECT x FROM @MyTableVar AS m JOIN Employee ON m.EmployeeID = Employee.EmployeeID" + ) + self.validate_identity( "SELECT DISTINCT DepartmentName, PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY BaseRate) OVER (PARTITION BY DepartmentName) AS MedianCont FROM dbo.DimEmployee" ) self.validate_all( + "SELECT DATEPART(year, TRY_CAST('2017-01-01' AS DATE))", + read={"postgres": "SELECT DATE_PART('year', '2017-01-01'::DATE)"}, + ) + self.validate_all( + "SELECT DATEPART(month, TRY_CAST('2017-03-01' AS DATE))", + read={"postgres": "SELECT DATE_PART('month', '2017-03-01'::DATE)"}, + ) + self.validate_all( + "SELECT DATEPART(day, TRY_CAST('2017-01-02' AS DATE))", + read={"postgres": "SELECT DATE_PART('day', '2017-01-02'::DATE)"}, + ) + self.validate_all( "SELECT CAST([a].[b] AS SMALLINT) FROM foo", write={ "tsql": 'SELECT CAST("a"."b" AS SMALLINT) FROM foo', @@ -281,19 +293,20 @@ WHERE def test_datename(self): self.validate_all( - "SELECT DATENAME(mm,'01-01-1970')", - write={"spark": "SELECT DATE_FORMAT('01-01-1970', 'MMMM')"}, + "SELECT DATENAME(mm,'1970-01-01')", + write={"spark": "SELECT DATE_FORMAT(CAST('1970-01-01' AS TIMESTAMP), 'MMMM')"}, ) self.validate_all( - "SELECT DATENAME(dw,'01-01-1970')", - write={"spark": "SELECT DATE_FORMAT('01-01-1970', 'EEEE')"}, + "SELECT DATENAME(dw,'1970-01-01')", + write={"spark": "SELECT DATE_FORMAT(CAST('1970-01-01' AS TIMESTAMP), 'EEEE')"}, ) def test_datepart(self): self.validate_all( - "SELECT DATEPART(month,'01-01-1970')", - write={"spark": "SELECT DATE_FORMAT('01-01-1970', 'MM')"}, + "SELECT DATEPART(month,'1970-01-01')", + write={"spark": "SELECT DATE_FORMAT(CAST('1970-01-01' AS TIMESTAMP), 'MM')"}, ) + self.validate_identity("DATEPART(YEAR, x)", "FORMAT(CAST(x AS DATETIME2), 'yyyy')") def test_convert_date_format(self): self.validate_all( diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 60a655a..162d627 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -414,6 +414,8 @@ SELECT 1 FROM a NATURAL LEFT JOIN b SELECT 1 FROM a NATURAL LEFT OUTER JOIN b SELECT 1 FROM a OUTER JOIN b ON a.foo = b.bar SELECT 1 FROM a FULL JOIN b ON a.foo = b.bar +SELECT 1 FROM a JOIN b JOIN c ON b.id = c.id ON a.id = b.id +SELECT * FROM a JOIN b JOIN c USING (id) USING (id) SELECT 1 UNION ALL SELECT 2 SELECT 1 EXCEPT SELECT 2 SELECT 1 EXCEPT SELECT 2 @@ -552,17 +554,17 @@ CREATE TABLE z AS ((WITH cte AS (SELECT 1) SELECT * FROM cte)) CREATE TABLE z (a INT UNIQUE) CREATE TABLE z (a INT AUTO_INCREMENT) CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT) -CREATE TABLE z (a INT REFERENCES parent(b, c)) -CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id)) -CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c)) -CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE NO ACTION) -CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE CASCADE) -CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE SET NULL) -CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE SET DEFAULT) -CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE NO ACTION) -CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE CASCADE) -CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE SET NULL) -CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE SET DEFAULT) +CREATE TABLE z (a INT REFERENCES parent (b, c)) +CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo (id)) +CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent (b, c)) +CREATE TABLE foo (bar INT REFERENCES baz (baz_id) ON DELETE NO ACTION) +CREATE TABLE foo (bar INT REFERENCES baz (baz_id) ON DELETE CASCADE) +CREATE TABLE foo (bar INT REFERENCES baz (baz_id) ON DELETE SET NULL) +CREATE TABLE foo (bar INT REFERENCES baz (baz_id) ON DELETE SET DEFAULT) +CREATE TABLE foo (bar INT REFERENCES baz (baz_id) ON UPDATE NO ACTION) +CREATE TABLE foo (bar INT REFERENCES baz (baz_id) ON UPDATE CASCADE) +CREATE TABLE foo (bar INT REFERENCES baz (baz_id) ON UPDATE SET NULL) +CREATE TABLE foo (bar INT REFERENCES baz (baz_id) ON UPDATE SET DEFAULT) CREATE TABLE asd AS SELECT asd FROM asd WITH NO DATA CREATE TABLE asd AS SELECT asd FROM asd WITH DATA CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY) @@ -573,7 +575,7 @@ CREATE TABLE IF NOT EXISTS customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDEN CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1 MINVALUE -1 MAXVALUE 1 NO CYCLE)) CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10)) CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (CYCLE)) -CREATE TABLE foo (baz_id INT REFERENCES baz(id) DEFERRABLE) +CREATE TABLE foo (baz_id INT REFERENCES baz (id) DEFERRABLE) CREATE TABLE foo (baz CHAR(4) CHARACTER SET LATIN UPPERCASE NOT CASESPECIFIC COMPRESS 'a') CREATE TABLE foo (baz DATE FORMAT 'YYYY/MM/DD' TITLE 'title' INLINE LENGTH 1 COMPRESS ('a', 'b')) CREATE TABLE t (title TEXT) @@ -648,7 +650,7 @@ ANALYZE a.y DELETE FROM x WHERE y > 1 DELETE FROM y DELETE FROM event USING sales WHERE event.eventid = sales.eventid -DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid +DELETE FROM event USING sales, bla WHERE event.eventid = sales.eventid DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid DELETE FROM event AS event USING sales AS s WHERE event.eventid = s.eventid PREPARE statement @@ -794,6 +796,7 @@ ALTER TABLE a ADD FOREIGN KEY (x, y) REFERENCES bla SELECT partition FROM a SELECT end FROM a SELECT id FROM b.a AS a QUALIFY ROW_NUMBER() OVER (PARTITION BY br ORDER BY sadf DESC) = 1 +SELECT * FROM x WHERE a GROUP BY a HAVING b SORT BY s ORDER BY c LIMIT d SELECT LEFT.FOO FROM BLA AS LEFT SELECT RIGHT.FOO FROM BLA AS RIGHT SELECT LEFT FROM LEFT LEFT JOIN RIGHT RIGHT JOIN LEFT @@ -834,3 +837,4 @@ SELECT * FROM case SELECT * FROM schema.case SELECT * FROM current_date SELECT * FROM schema.current_date +SELECT /*+ SOME_HINT(foo) */ 1 diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql index 46cd6d8..1a61334 100644 --- a/tests/fixtures/pretty.sql +++ b/tests/fixtures/pretty.sql @@ -363,3 +363,23 @@ SELECT A.* EXCEPT (A.COL_1, A.COL_2) FROM TABLE_1 AS A; + +SELECT * +FROM a +JOIN b + JOIN c + ON b.id = c.id + ON a.id = b.id +CROSS JOIN d +JOIN e + ON d.id = e.id; +SELECT + * +FROM a +JOIN b + JOIN c + ON b.id = c.id + ON a.id = b.id +CROSS JOIN d +JOIN e + ON d.id = e.id; diff --git a/tests/test_executor.py b/tests/test_executor.py index 6dd530f..9dacbbf 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -58,8 +58,11 @@ class TestExecutor(unittest.TestCase): source.rename(columns={column: target.columns[i]}, inplace=True) def test_py_dialect(self): - self.assertEqual(Python().generate(parse_one("'x '''")), r"'x \''") - self.assertEqual(Python().generate(parse_one("MAP([1], [2])")), "MAP([1], [2])") + generate = Python().generate + self.assertEqual(generate(parse_one("'x '''")), r"'x \''") + self.assertEqual(generate(parse_one("MAP([1], [2])")), "MAP([1], [2])") + self.assertEqual(generate(parse_one("1 is null")), "1 == None") + self.assertEqual(generate(parse_one("x is null")), "scope[None][x] is None") def test_optimized_tpch(self): for i, (sql, optimized) in enumerate(self.sqls[:20], start=1): @@ -620,6 +623,7 @@ class TestExecutor(unittest.TestCase): ("TIMESTRTOTIME('2022-01-01')", datetime.datetime(2022, 1, 1)), ("LEFT('12345', 3)", "123"), ("RIGHT('12345', 3)", "345"), + ("DATEDIFF('2022-01-03'::date, '2022-01-01'::TIMESTAMP::DATE)", 2), ]: with self.subTest(sql): result = execute(f"SELECT {sql}") @@ -699,6 +703,21 @@ class TestExecutor(unittest.TestCase): [(2, 25.0)], ("_col_0", "_col_1"), ), + ( + "SELECT a FROM x GROUP BY a ORDER BY AVG(b)", + [(2,), (1,), (3,)], + ("a",), + ), + ( + "SELECT a, SUM(b) FROM x GROUP BY a ORDER BY COUNT(*)", + [(3, 28), (1, 50), (2, 45)], + ("a", "_col_1"), + ), + ( + "SELECT a, SUM(b) FROM x GROUP BY a ORDER BY COUNT(*) DESC", + [(1, 50), (2, 45), (3, 28)], + ("a", "_col_1"), + ), ): with self.subTest(sql): result = execute(sql, tables=tables) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index f83addb..f050c0b 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -15,14 +15,14 @@ class TestExpressions(unittest.TestCase): self.assertEqual(parse_one("x(1)").find(exp.Literal).depth, 1) def test_eq(self): - self.assertEqual(exp.to_identifier("a"), exp.to_identifier("A")) + self.assertNotEqual(exp.to_identifier("a"), exp.to_identifier("A")) self.assertEqual( exp.Column(table=exp.to_identifier("b"), this=exp.to_identifier("b")), exp.Column(this=exp.to_identifier("b"), table=exp.to_identifier("b")), ) - self.assertEqual(exp.to_identifier("a", quoted=True), exp.to_identifier("A")) + self.assertNotEqual(exp.to_identifier("a", quoted=True), exp.to_identifier("A")) self.assertNotEqual(exp.to_identifier("A", quoted=True), exp.to_identifier("A")) self.assertNotEqual( exp.to_identifier("A", quoted=True), exp.to_identifier("a", quoted=True) @@ -31,9 +31,9 @@ class TestExpressions(unittest.TestCase): self.assertNotEqual(parse_one("'1'"), parse_one("1")) self.assertEqual(parse_one("`a`", read="hive"), parse_one('"a"')) self.assertEqual(parse_one("`a`", read="hive"), parse_one('"a" ')) - self.assertEqual(parse_one("`a`.b", read="hive"), parse_one('"a"."b"')) + self.assertEqual(parse_one("`a`.`b`", read="hive"), parse_one('"a"."b"')) self.assertEqual(parse_one("select a, b+1"), parse_one("SELECT a, b + 1")) - self.assertEqual(parse_one("`a`.`b`.`c`", read="hive"), parse_one("a.b.c")) + self.assertNotEqual(parse_one("`a`.`b`.`c`", read="hive"), parse_one("a.b.c")) self.assertNotEqual(parse_one("a.b.c.d", read="hive"), parse_one("a.b.c")) self.assertEqual(parse_one("a.b.c.d", read="hive"), parse_one("a.b.c.d")) self.assertEqual(parse_one("a + b * c - 1.0"), parse_one("a+b*c-1.0")) @@ -338,7 +338,7 @@ class TestExpressions(unittest.TestCase): { parse_one("select a.b"), parse_one("1+2"), - parse_one('"a".b'), + parse_one('"a"."b"'), parse_one("a.b.c.d"), }, { diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 40eef9f..0608903 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -146,7 +146,8 @@ class TestOptimizer(unittest.TestCase): df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df() assert_frame_equal(df1, df2) - def test_optimize(self): + @patch("sqlglot.generator.logger") + def test_optimize(self, logger): self.assertEqual(optimizer.optimize("x = 1 + 1", identify=None).sql(), "x = 2") schema = { @@ -199,7 +200,8 @@ class TestOptimizer(unittest.TestCase): self.check_file("normalize", normalize) - def test_qualify_columns(self): + @patch("sqlglot.generator.logger") + def test_qualify_columns(self, logger): self.assertEqual( optimizer.qualify_columns.qualify_columns( parse_one("WITH x AS (SELECT a FROM db.y) SELECT z FROM db.x"), @@ -229,6 +231,17 @@ class TestOptimizer(unittest.TestCase): 'WITH "x" AS (SELECT "y"."a" AS "a" FROM "DB"."Y" AS "y" CROSS JOIN "a"."b"."INFORMATION_SCHEMA"."COLUMNS" AS "columns") SELECT "x"."a" AS "a" FROM "x"', ) + self.assertEqual( + optimizer.qualify.qualify( + parse_one( + "CREATE FUNCTION udfs.`myTest`(`x` FLOAT64) AS (1)", + read="bigquery", + ), + dialect="bigquery", + ).sql(dialect="bigquery"), + "CREATE FUNCTION `udfs`.`myTest`(`x` FLOAT64) AS (1)", + ) + self.check_file("qualify_columns", qualify_columns, execute=True, schema=self.schema) def test_qualify_columns__with_invisible(self): @@ -307,7 +320,8 @@ class TestOptimizer(unittest.TestCase): pretty=True, ) - def test_merge_subqueries(self): + @patch("sqlglot.generator.logger") + def test_merge_subqueries(self, logger): optimize = partial( optimizer.optimize, rules=[ @@ -575,7 +589,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') def test_function_annotation(self): schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}} - sql = "SELECT x.cola || TRIM(x.colb) AS col, DATE(x.colb) FROM x AS x" + sql = ( + "SELECT x.cola || TRIM(x.colb) AS col, DATE(x.colb), DATEFROMPARTS(y, m, d) FROM x AS x" + ) expression = annotate_types(parse_one(sql), schema=schema) concat_expr_alias = expression.expressions[0] @@ -590,6 +606,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') date_expr = expression.expressions[1] self.assertEqual(date_expr.type.this, exp.DataType.Type.DATE) + date_expr = expression.expressions[2] + self.assertEqual(date_expr.type.this, exp.DataType.Type.DATE) + sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x" case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] diff --git a/tests/test_parser.py b/tests/test_parser.py index 96192cd..2fa6a09 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,3 +1,4 @@ +import time import unittest from unittest.mock import patch @@ -67,7 +68,7 @@ class TestParser(unittest.TestCase): }, ] with self.assertRaises(ParseError) as ctx: - parse_one("SELECT 1;", "sqlite", [exp.From, exp.Join]) + parse_one("SELECT 1;", "sqlite", into=[exp.From, exp.Join]) self.assertEqual(str(ctx.exception), expected_message) self.assertEqual(ctx.exception.errors, expected_errors) @@ -318,6 +319,7 @@ class TestParser(unittest.TestCase): self.assertIsInstance(parse_one("TIMESTAMP()"), exp.Func) self.assertIsInstance(parse_one("map.x"), exp.Column) self.assertIsInstance(parse_one("CAST(x AS CHAR(5))").to.expressions[0], exp.DataTypeSize) + self.assertEqual(parse_one("1::int64", dialect="bigquery"), parse_one("CAST(1 AS BIGINT)")) def test_set_expression(self): set_ = parse_one("SET") @@ -522,6 +524,55 @@ class TestParser(unittest.TestCase): columns = expr.args["from"].this.args["pivots"][0].args["columns"] self.assertEqual(expected_columns, [col.sql(dialect=dialect) for col in columns]) + def test_parse_nested(self): + now = time.time() + query = parse_one( + """ + select * + FROM a + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + LEFT JOIN b ON a.id = b.id + """ + ) + self.assertIsNotNone(query) + self.assertLessEqual(time.time() - now, 0.1) + def test_parse_properties(self): self.assertEqual( parse_one("create materialized table x").sql(), "CREATE MATERIALIZED TABLE x" diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 8d762d3..1138b4e 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -511,11 +511,11 @@ FROM v""", ) @mock.patch("sqlglot.helper.logger") - def test_index_offset(self, mock_logger): + def test_index_offset(self, logger): self.validate("x[0]", "x[1]", write="presto", identity=False) self.validate("x[1]", "x[0]", read="presto", identity=False) - mock_logger.warning.assert_any_call("Applying array index offset (%s)", 1) - mock_logger.warning.assert_any_call("Applying array index offset (%s)", -1) + logger.warning.assert_any_call("Applying array index offset (%s)", 1) + logger.warning.assert_any_call("Applying array index offset (%s)", -1) self.validate("x[x - 1]", "x[x - 1]", write="presto", identity=False) self.validate( |