From 93346175ed97c685979fba99a6ae68268484d8c1 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 17 Jun 2024 11:15:16 +0200 Subject: Adding upstream version 25.1.0. Signed-off-by: Daniel Baumann --- tests/dialects/test_bigquery.py | 15 +++++++- tests/dialects/test_duckdb.py | 16 ++++++++ tests/dialects/test_mysql.py | 11 ++++-- tests/dialects/test_oracle.py | 70 +++++++++++++++++++++++++++++++++-- tests/dialects/test_postgres.py | 23 ++++++++++++ tests/dialects/test_presto.py | 2 + tests/dialects/test_redshift.py | 9 +++++ tests/dialects/test_snowflake.py | 12 ++++-- tests/dialects/test_spark.py | 11 ++++++ tests/dialects/test_sqlite.py | 1 + tests/dialects/test_tsql.py | 67 ++++++++++++++++----------------- tests/fixtures/identity.sql | 1 + tests/fixtures/optimizer/simplify.sql | 3 ++ tests/test_optimizer.py | 52 ++++++++++++++++++++------ tests/test_parser.py | 7 +++- 15 files changed, 240 insertions(+), 60 deletions(-) (limited to 'tests') diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index bfaf009..ae8ed16 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -20,6 +20,14 @@ class TestBigQuery(Validator): maxDiff = None def test_bigquery(self): + self.validate_all( + "EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))", + write={ + "bigquery": "EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))", + "duckdb": "EXTRACT(HOUR FROM MAKE_TIMESTAMP(2008, 12, 25, 15, 30, 00))", + "snowflake": "DATE_PART(HOUR, TIMESTAMP_FROM_PARTS(2008, 12, 25, 15, 30, 00))", + }, + ) self.validate_identity( """CREATE TEMPORARY FUNCTION FOO() RETURNS STRING @@ -619,9 +627,9 @@ LANGUAGE js AS 'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)', write={ "bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)", - "databricks": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", + "databricks": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", "mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)", - "spark": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", + "spark": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", }, ) self.validate_all( @@ -761,12 +769,15 @@ LANGUAGE js AS "clickhouse": "SHA256(x)", "presto": "SHA256(x)", "trino": "SHA256(x)", + "postgres": "SHA256(x)", }, write={ "bigquery": "SHA256(x)", "spark2": "SHA2(x, 256)", "clickhouse": "SHA256(x)", + "postgres": "SHA256(x)", "presto": "SHA256(x)", + "redshift": "SHA2(x, 256)", "trino": "SHA256(x)", }, ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index cd68ff9..2bde478 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -18,6 +18,13 @@ class TestDuckDB(Validator): "WITH _data AS (SELECT [STRUCT(1 AS a, 2 AS b), STRUCT(2 AS a, 3 AS b)] AS col) SELECT col.b FROM _data, UNNEST(_data.col) AS col WHERE col.a = 1", ) + self.validate_all( + "SELECT straight_join", + write={ + "duckdb": "SELECT straight_join", + "mysql": "SELECT `straight_join`", + }, + ) self.validate_all( "SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMP)", read={ @@ -278,6 +285,7 @@ class TestDuckDB(Validator): self.validate_identity("FROM tbl", "SELECT * FROM tbl") self.validate_identity("x -> '$.family'") self.validate_identity("CREATE TABLE color (name ENUM('RED', 'GREEN', 'BLUE'))") + self.validate_identity("SELECT * FROM foo WHERE bar > $baz AND bla = $bob") self.validate_identity( "SELECT * FROM x LEFT JOIN UNNEST(y)", "SELECT * FROM x LEFT JOIN UNNEST(y) ON TRUE" ) @@ -1000,6 +1008,7 @@ class TestDuckDB(Validator): self.validate_identity("CAST(x AS CHAR)", "CAST(x AS TEXT)") self.validate_identity("CAST(x AS BPCHAR)", "CAST(x AS TEXT)") self.validate_identity("CAST(x AS STRING)", "CAST(x AS TEXT)") + self.validate_identity("CAST(x AS VARCHAR)", "CAST(x AS TEXT)") self.validate_identity("CAST(x AS INT1)", "CAST(x AS TINYINT)") self.validate_identity("CAST(x AS FLOAT4)", "CAST(x AS REAL)") self.validate_identity("CAST(x AS FLOAT)", "CAST(x AS REAL)") @@ -1027,6 +1036,13 @@ class TestDuckDB(Validator): "CAST([{'a': 1}] AS STRUCT(a BIGINT)[])", ) + self.validate_all( + "CAST(x AS VARCHAR(5))", + write={ + "duckdb": "CAST(x AS TEXT)", + "postgres": "CAST(x AS TEXT)", + }, + ) self.validate_all( "CAST(x AS DECIMAL(38, 0))", read={ diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index fdb7e91..280ebbf 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -21,6 +21,9 @@ class TestMySQL(Validator): self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))") self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))") self.validate_identity("ALTER TABLE t1 ADD COLUMN x INT, ALGORITHM=INPLACE, LOCK=EXCLUSIVE") + self.validate_identity("ALTER TABLE t ADD INDEX `i` (`c`)") + self.validate_identity("ALTER TABLE t ADD UNIQUE `i` (`c`)") + self.validate_identity("ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT") self.validate_identity( "CREATE TABLE `oauth_consumer` (`key` VARCHAR(32) NOT NULL, UNIQUE `OAUTH_CONSUMER_KEY` (`key`))" ) @@ -60,6 +63,10 @@ class TestMySQL(Validator): self.validate_identity( "CREATE OR REPLACE VIEW my_view AS SELECT column1 AS `boo`, column2 AS `foo` FROM my_table WHERE column3 = 'some_value' UNION SELECT q.* FROM fruits_table, JSON_TABLE(Fruits, '$[*]' COLUMNS(id VARCHAR(255) PATH '$.$id', value VARCHAR(255) PATH '$.value')) AS q", ) + self.validate_identity( + "ALTER TABLE t ADD KEY `i` (`c`)", + "ALTER TABLE t ADD INDEX `i` (`c`)", + ) self.validate_identity( "CREATE TABLE `foo` (`id` char(36) NOT NULL DEFAULT (uuid()), PRIMARY KEY (`id`), UNIQUE KEY `id` (`id`))", "CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))", @@ -76,9 +83,6 @@ class TestMySQL(Validator): "ALTER TABLE test_table ALTER COLUMN test_column SET DATA TYPE LONGTEXT", "ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT", ) - self.validate_identity( - "ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT", - ) self.validate_identity( "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP) DEFAULT CHARSET=utf8 ROW_FORMAT=DYNAMIC", "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC", @@ -113,6 +117,7 @@ class TestMySQL(Validator): ) def test_identity(self): + self.validate_identity("SELECT e.* FROM e STRAIGHT_JOIN p ON e.x = p.y") self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1") self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')") self.validate_identity("SELECT @var1 := 1, @var2") diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 526b0b5..7cc4d72 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -1,5 +1,5 @@ -from sqlglot import exp -from sqlglot.errors import UnsupportedError +from sqlglot import exp, UnsupportedError +from sqlglot.dialects.oracle import eliminate_join_marks from tests.dialects.test_dialect import Validator @@ -43,6 +43,7 @@ class TestOracle(Validator): self.validate_identity("SELECT * FROM table_name SAMPLE (25) s") self.validate_identity("SELECT COUNT(*) * 10 FROM orders SAMPLE (10) SEED (1)") self.validate_identity("SELECT * FROM V$SESSION") + self.validate_identity("SELECT TO_DATE('January 15, 1989, 11:00 A.M.')") self.validate_identity( "SELECT last_name, employee_id, manager_id, LEVEL FROM employees START WITH employee_id = 100 CONNECT BY PRIOR employee_id = manager_id ORDER SIBLINGS BY last_name" ) @@ -249,7 +250,8 @@ class TestOracle(Validator): self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y") self.validate_all( - "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)", write={"": UnsupportedError} + "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)", + write={"": UnsupportedError}, ) self.validate_all( "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)", @@ -413,3 +415,65 @@ WHERE for query in (f"{body}{start}{connect}", f"{body}{connect}{start}"): self.validate_identity(query, pretty, pretty=True) + + def test_eliminate_join_marks(self): + test_sql = [ + ( + "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) > 5", + "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y > 5", + ), + ( + "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) IS NULL", + "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y IS NULL", + ), + ( + "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y IS NULL", + "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T2.y IS NULL", + ), + ( + "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T1.Z > 4", + "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T1.Z > 4", + ), + ( + "SELECT * FROM table1, table2 WHERE table1.column = table2.column(+)", + "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column", + ), + ( + "SELECT * FROM table1, table2, table3, table4 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+) and table1.column = table4.column(+)", + "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column LEFT JOIN table4 ON table1.column = table4.column", + ), + ( + "SELECT * FROM table1, table2, table3 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+)", + "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column", + ), + ( + "SELECT table1.id, table2.cloumn1, table3.id FROM table1, table2, (SELECT tableInner1.id FROM tableInner1, tableInner2 WHERE tableInner1.id = tableInner2.id(+)) AS table3 WHERE table1.id = table2.id(+) and table1.id = table3.id(+)", + "SELECT table1.id, table2.cloumn1, table3.id FROM table1 LEFT JOIN table2 ON table1.id = table2.id LEFT JOIN (SELECT tableInner1.id FROM tableInner1 LEFT JOIN tableInner2 ON tableInner1.id = tableInner2.id) table3 ON table1.id = table3.id", + ), + # 2 join marks on one side of predicate + ( + "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + table2.column2(+)", + "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + table2.column2", + ), + # join mark and expression + ( + "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + 25", + "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + 25", + ), + ] + + for original, expected in test_sql: + with self.subTest(original): + self.assertEqual( + eliminate_join_marks(self.parse_one(original)).sql(dialect=self.dialect), + expected, + ) + + def test_query_restrictions(self): + for restriction in ("READ ONLY", "CHECK OPTION"): + for constraint_name in (" CONSTRAINT name", ""): + with self.subTest(f"Restriction: {restriction}"): + self.validate_identity(f"SELECT * FROM tbl WITH {restriction}{constraint_name}") + self.validate_identity( + f"CREATE VIEW view AS SELECT * FROM tbl WITH {restriction}{constraint_name}" + ) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 74753be..071677d 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -8,6 +8,7 @@ class TestPostgres(Validator): dialect = "postgres" def test_postgres(self): + self.validate_identity("SHA384(x)") self.validate_identity( 'CREATE TABLE x (a TEXT COLLATE "de_DE")', "CREATE TABLE x (a TEXT COLLATE de_DE)" ) @@ -724,6 +725,28 @@ class TestPostgres(Validator): self.validate_identity("cast(a as FLOAT8)", "CAST(a AS DOUBLE PRECISION)") self.validate_identity("cast(a as FLOAT4)", "CAST(a AS REAL)") + self.validate_all( + "1 / DIV(4, 2)", + read={ + "postgres": "1 / DIV(4, 2)", + }, + write={ + "sqlite": "1 / CAST(CAST(CAST(4 AS REAL) / 2 AS INTEGER) AS REAL)", + "duckdb": "1 / CAST(4 // 2 AS DECIMAL)", + "bigquery": "1 / CAST(DIV(4, 2) AS NUMERIC)", + }, + ) + self.validate_all( + "CAST(DIV(4, 2) AS DECIMAL(5, 3))", + read={ + "duckdb": "CAST(4 // 2 AS DECIMAL(5, 3))", + }, + write={ + "duckdb": "CAST(CAST(4 // 2 AS DECIMAL) AS DECIMAL(5, 3))", + "postgres": "CAST(DIV(4, 2) AS DECIMAL(5, 3))", + }, + ) + def test_ddl(self): # Checks that user-defined types are parsed into DataType instead of Identifier self.parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is( diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index f1bbcc1..ebb270a 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -564,6 +564,7 @@ class TestPresto(Validator): self.validate_all( f"{prefix}'Hello winter \\2603 !'", write={ + "oracle": "U'Hello winter \\2603 !'", "presto": "U&'Hello winter \\2603 !'", "snowflake": "'Hello winter \\u2603 !'", "spark": "'Hello winter \\u2603 !'", @@ -572,6 +573,7 @@ class TestPresto(Validator): self.validate_all( f"{prefix}'Hello winter #2603 !' UESCAPE '#'", write={ + "oracle": "U'Hello winter \\2603 !'", "presto": "U&'Hello winter #2603 !' UESCAPE '#'", "snowflake": "'Hello winter \\u2603 !'", "spark": "'Hello winter \\u2603 !'", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 844fe46..69793c7 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -281,6 +281,9 @@ class TestRedshift(Validator): "redshift": "SELECT DATEADD(MONTH, 18, '2008-02-28')", "snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMP))", "tsql": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS DATETIME2))", + "spark": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')", + "spark2": "SELECT ADD_MONTHS('2008-02-28', 18)", + "databricks": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')", }, ) self.validate_all( @@ -585,3 +588,9 @@ FROM ( self.assertEqual( ast.sql("redshift"), "SELECT * FROM x AS a, a.b AS c, c.d.e AS f, f.g.h.i.j.k AS l" ) + + def test_join_markers(self): + self.validate_identity( + "select a.foo, b.bar, a.baz from a, b where a.baz = b.baz (+)", + "SELECT a.foo, b.bar, a.baz FROM a, b WHERE a.baz = b.baz (+)", + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 9d9371d..1286436 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -125,6 +125,10 @@ WHERE "SELECT a:from::STRING, a:from || ' test' ", "SELECT CAST(GET_PATH(a, 'from') AS TEXT), GET_PATH(a, 'from') || ' test'", ) + self.validate_identity( + "SELECT a:select", + "SELECT GET_PATH(a, 'select')", + ) self.validate_identity("x:from", "GET_PATH(x, 'from')") self.validate_identity( "value:values::string::int", @@ -1196,16 +1200,16 @@ WHERE for constraint_prefix in ("WITH ", ""): with self.subTest(f"Constraint prefix: {constraint_prefix}"): self.validate_identity( - f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p)", - "CREATE TABLE t (id INT MASKING POLICY p)", + f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p.q.r)", + "CREATE TABLE t (id INT MASKING POLICY p.q.r)", ) self.validate_identity( f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p USING (c1, c2, c3))", "CREATE TABLE t (id INT MASKING POLICY p USING (c1, c2, c3))", ) self.validate_identity( - f"CREATE TABLE t (id INT {constraint_prefix}PROJECTION POLICY p)", - "CREATE TABLE t (id INT PROJECTION POLICY p)", + f"CREATE TABLE t (id INT {constraint_prefix}PROJECTION POLICY p.q.r)", + "CREATE TABLE t (id INT PROJECTION POLICY p.q.r)", ) self.validate_identity( f"CREATE TABLE t (id INT {constraint_prefix}TAG (key1='value_1', key2='value_2'))", diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index ecc152f..bff91bf 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -563,6 +563,7 @@ TBLPROPERTIES ( "SELECT DATE_ADD(my_date_column, 1)", write={ "spark": "SELECT DATE_ADD(my_date_column, 1)", + "spark2": "SELECT DATE_ADD(my_date_column, 1)", "bigquery": "SELECT DATE_ADD(CAST(CAST(my_date_column AS DATETIME) AS DATE), INTERVAL 1 DAY)", }, ) @@ -675,6 +676,16 @@ TBLPROPERTIES ( "spark": "SELECT ARRAY_SORT(x)", }, ) + self.validate_all( + "SELECT DATE_ADD(MONTH, 20, col)", + read={ + "spark": "SELECT TIMESTAMPADD(MONTH, 20, col)", + }, + write={ + "spark": "SELECT DATE_ADD(MONTH, 20, col)", + "databricks": "SELECT DATE_ADD(MONTH, 20, col)", + }, + ) def test_bool_or(self): self.validate_all( diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index f3cde0b..46bbadc 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -202,6 +202,7 @@ class TestSQLite(Validator): "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)", read={ "mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)", + "postgres": "CREATE TABLE z (a INT GENERATED BY DEFAULT AS IDENTITY NOT NULL UNIQUE PRIMARY KEY)", }, write={ "sqlite": "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)", diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 92adf7a..7455650 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1,12 +1,18 @@ -from sqlglot import exp, parse, parse_one +from sqlglot import exp, parse from tests.dialects.test_dialect import Validator from sqlglot.errors import ParseError +from sqlglot.optimizer.annotate_types import annotate_types class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): + self.assertEqual( + annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"), + "SELECT 1 WHERE EXISTS(SELECT 1)", + ) + self.validate_identity("CREATE view a.b.c", "CREATE VIEW b.c") self.validate_identity("DROP view a.b.c", "DROP VIEW b.c") self.validate_identity("ROUND(x, 1, 0)") @@ -217,9 +223,9 @@ class TestTSQL(Validator): "CREATE TABLE [db].[tbl] ([a] INTEGER)", ) - projection = parse_one("SELECT a = 1", read="tsql").selects[0] - projection.assert_is(exp.Alias) - projection.args["alias"].assert_is(exp.Identifier) + self.validate_identity("SELECT a = 1", "SELECT 1 AS a").selects[0].assert_is( + exp.Alias + ).args["alias"].assert_is(exp.Identifier) self.validate_all( "IF OBJECT_ID('tempdb.dbo.#TempTableName', 'U') IS NOT NULL DROP TABLE #TempTableName", @@ -756,12 +762,9 @@ class TestTSQL(Validator): for view_attr in ("ENCRYPTION", "SCHEMABINDING", "VIEW_METADATA"): self.validate_identity(f"CREATE VIEW a.b WITH {view_attr} AS SELECT * FROM x") - expression = parse_one("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B", dialect="tsql") - self.assertIsInstance(expression, exp.AlterTable) - self.assertIsInstance(expression.args["actions"][0], exp.Drop) - self.assertEqual( - expression.sql(dialect="tsql"), "ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B" - ) + self.validate_identity("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B").assert_is( + exp.AlterTable + ).args["actions"][0].assert_is(exp.Drop) for clustered_keyword in ("CLUSTERED", "NONCLUSTERED"): self.validate_identity( @@ -795,10 +798,10 @@ class TestTSQL(Validator): ) self.validate_all( - "CREATE TABLE [#temptest] (name VARCHAR)", + "CREATE TABLE [#temptest] (name INTEGER)", read={ - "duckdb": "CREATE TEMPORARY TABLE 'temptest' (name VARCHAR)", - "tsql": "CREATE TABLE [#temptest] (name VARCHAR)", + "duckdb": "CREATE TEMPORARY TABLE 'temptest' (name INTEGER)", + "tsql": "CREATE TABLE [#temptest] (name INTEGER)", }, ) self.validate_all( @@ -1632,27 +1635,23 @@ WHERE ) def test_identifier_prefixes(self): - expr = parse_one("#x", read="tsql") - self.assertIsInstance(expr, exp.Column) - self.assertIsInstance(expr.this, exp.Identifier) - self.assertTrue(expr.this.args.get("temporary")) - self.assertEqual(expr.sql("tsql"), "#x") - - expr = parse_one("##x", read="tsql") - self.assertIsInstance(expr, exp.Column) - self.assertIsInstance(expr.this, exp.Identifier) - self.assertTrue(expr.this.args.get("global")) - self.assertEqual(expr.sql("tsql"), "##x") - - expr = parse_one("@x", read="tsql") - self.assertIsInstance(expr, exp.Parameter) - self.assertIsInstance(expr.this, exp.Var) - self.assertEqual(expr.sql("tsql"), "@x") + self.assertTrue( + self.validate_identity("#x") + .assert_is(exp.Column) + .this.assert_is(exp.Identifier) + .args.get("temporary") + ) + self.assertTrue( + self.validate_identity("##x") + .assert_is(exp.Column) + .this.assert_is(exp.Identifier) + .args.get("global") + ) - table = parse_one("select * from @x", read="tsql").args["from"].this - self.assertIsInstance(table, exp.Table) - self.assertIsInstance(table.this, exp.Parameter) - self.assertIsInstance(table.this.this, exp.Var) + self.validate_identity("@x").assert_is(exp.Parameter).this.assert_is(exp.Var) + self.validate_identity("SELECT * FROM @x").args["from"].this.assert_is( + exp.Table + ).this.assert_is(exp.Parameter).this.assert_is(exp.Var) self.validate_all( "SELECT @x", @@ -1663,8 +1662,6 @@ WHERE "tsql": "SELECT @x", }, ) - - def test_temp_table(self): self.validate_all( "SELECT * FROM #mytemptable", write={ diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index e31031d..4dc4aa1 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -872,3 +872,4 @@ SELECT name SELECT copy SELECT rollup SELECT unnest +SELECT * FROM a STRAIGHT_JOIN b diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 87b42d1..6035ee6 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -1047,6 +1047,9 @@ x < CAST('2021-01-02' AS DATE) AND x >= CAST('2021-01-01' AS DATE); TIMESTAMP_TRUNC(x, YEAR) = CAST(CAST('2021-01-01 01:02:03' AS DATE) AS DATETIME); x < CAST('2022-01-01 00:00:00' AS DATETIME) AND x >= CAST('2021-01-01 00:00:00' AS DATETIME); +DATE_TRUNC('day', CAST(x AS DATE)) <= CAST('2021-01-01 01:02:03' AS TIMESTAMP); +CAST(x AS DATE) < CAST('2021-01-02 01:02:03' AS TIMESTAMP); + -------------------------------------- -- EQUALITY -------------------------------------- diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 41a5015..81b9731 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -29,7 +29,11 @@ def parse_and_optimize(func, sql, read_dialect, **kwargs): def qualify_columns(expression, **kwargs): expression = optimizer.qualify.qualify( - expression, infer_schema=True, validate_qualify_columns=False, identify=False, **kwargs + expression, + infer_schema=True, + validate_qualify_columns=False, + identify=False, + **kwargs, ) return expression @@ -111,7 +115,14 @@ class TestOptimizer(unittest.TestCase): } def check_file( - self, file, func, pretty=False, execute=False, set_dialect=False, only=None, **kwargs + self, + file, + func, + pretty=False, + execute=False, + set_dialect=False, + only=None, + **kwargs, ): with ProcessPoolExecutor() as pool: results = {} @@ -331,7 +342,11 @@ class TestOptimizer(unittest.TestCase): ) self.check_file( - "qualify_columns", qualify_columns, execute=True, schema=self.schema, set_dialect=True + "qualify_columns", + qualify_columns, + execute=True, + schema=self.schema, + set_dialect=True, ) self.check_file( "qualify_columns_ddl", qualify_columns, schema=self.schema, set_dialect=True @@ -343,7 +358,8 @@ class TestOptimizer(unittest.TestCase): def test_pushdown_cte_alias_columns(self): self.check_file( - "pushdown_cte_alias_columns", optimizer.qualify_columns.pushdown_cte_alias_columns + "pushdown_cte_alias_columns", + optimizer.qualify_columns.pushdown_cte_alias_columns, ) def test_qualify_columns__invalid(self): @@ -405,7 +421,8 @@ class TestOptimizer(unittest.TestCase): self.assertEqual(optimizer.simplify.gen(query), optimizer.simplify.gen(query.copy())) anon_unquoted_identifier = exp.Anonymous( - this=exp.to_identifier("anonymous"), expressions=[exp.column("x"), exp.column("y")] + this=exp.to_identifier("anonymous"), + expressions=[exp.column("x"), exp.column("y")], ) self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS(x,y)") @@ -416,7 +433,10 @@ class TestOptimizer(unittest.TestCase): anon_invalid = exp.Anonymous(this=5) optimizer.simplify.gen(anon_invalid) - self.assertIn("Anonymous.this expects a str or an Identifier, got 'int'.", str(e.exception)) + self.assertIn( + "Anonymous.this expects a str or an Identifier, got 'int'.", + str(e.exception), + ) sql = parse_one( """ @@ -906,7 +926,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively for d, t in zip( - cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT] + cte_select.find_all(exp.Subquery), + [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT], ): self.assertEqual(d.this.expressions[0].this.type.this, t) @@ -1020,7 +1041,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') for (func, col), target_type in tests.items(): expression = annotate_types( - parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema + parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), + schema=schema, ) self.assertEqual(expression.expressions[0].type.this, target_type) @@ -1035,7 +1057,13 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(exp.DataType.Type.INT, expression.selects[1].type.this) def test_nested_type_annotation(self): - schema = {"order": {"customer_id": "bigint", "item_id": "bigint", "item_price": "numeric"}} + schema = { + "order": { + "customer_id": "bigint", + "item_id": "bigint", + "item_price": "numeric", + } + } sql = """ SELECT ARRAY_AGG(DISTINCT order.item_id) FILTER (WHERE order.item_price > 10) AS items, FROM order AS order @@ -1057,7 +1085,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.selects[0].type.sql(dialect="bigquery"), "STRUCT<`f` STRING>") self.assertEqual( - expression.selects[1].type.sql(dialect="bigquery"), "ARRAY>" + expression.selects[1].type.sql(dialect="bigquery"), + "ARRAY>", ) expression = annotate_types( @@ -1206,7 +1235,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual( optimizer.optimize( - parse_one("SELECT * FROM a"), schema=MappingSchema(schema, dialect="bigquery") + parse_one("SELECT * FROM a"), + schema=MappingSchema(schema, dialect="bigquery"), ), parse_one('SELECT "a"."a" AS "a", "a"."b" AS "b" FROM "a" AS "a"'), ) diff --git a/tests/test_parser.py b/tests/test_parser.py index 2cefc07..d6849c3 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -106,6 +106,7 @@ class TestParser(unittest.TestCase): expr = parse_one("SELECT foo IN UNNEST(bla) AS bar") self.assertIsInstance(expr.selects[0], exp.Alias) self.assertEqual(expr.selects[0].output_name, "bar") + self.assertIsNotNone(parse_one("select unnest(x)").find(exp.Unnest)) def test_unary_plus(self): self.assertEqual(parse_one("+15"), exp.Literal.number(15)) @@ -880,10 +881,12 @@ class TestParser(unittest.TestCase): self.assertIsInstance(parse_one("a IS DISTINCT FROM b OR c IS DISTINCT FROM d"), exp.Or) def test_trailing_comments(self): - expressions = parse(""" + expressions = parse( + """ select * from x; -- my comment - """) + """ + ) self.assertEqual( ";\n".join(e.sql() for e in expressions), "SELECT * FROM x;\n/* my comment */" -- cgit v1.2.3