diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-02-19 13:44:59 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-02-19 13:44:59 +0000 |
commit | ef2db38de92f2329c1c366318bddfc7e3dee8415 (patch) | |
tree | dee41de1eb0e05f2f6805b77df41a71b3aa66ec2 /tests | |
parent | Adding upstream version 11.0.1. (diff) | |
download | sqlglot-ef2db38de92f2329c1c366318bddfc7e3dee8415.tar.xz sqlglot-ef2db38de92f2329c1c366318bddfc7e3dee8415.zip |
Adding upstream version 11.1.3.upstream/11.1.3
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
24 files changed, 349 insertions, 85 deletions
diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index 8b44b9f..f155065 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -1152,17 +1152,17 @@ class TestFunctions(unittest.TestCase): def test_regexp_extract(self): col_str = SF.regexp_extract("cola", r"(\d+)-(\d+)", 1) - self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col_str.sql()) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\d+)-(\\d+)', 1)", col_str.sql()) col = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)", 1) - self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col.sql()) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\d+)-(\\d+)', 1)", col.sql()) col_no_idx = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)") - self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)')", col_no_idx.sql()) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\d+)-(\\d+)')", col_no_idx.sql()) def test_regexp_replace(self): col_str = SF.regexp_replace("cola", r"(\d+)", "--") - self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col_str.sql()) + self.assertEqual("REGEXP_REPLACE(cola, '(\\d+)', '--')", col_str.sql()) col = SF.regexp_replace(SF.col("cola"), r"(\d+)", "--") - self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col.sql()) + self.assertEqual("REGEXP_REPLACE(cola, '(\\d+)', '--')", col.sql()) def test_initcap(self): col_str = SF.initcap("cola") diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 241f496..7b18a6a 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -36,30 +36,33 @@ class TestBigQuery(Validator): self.validate_all( r'r"""/\*.*\*/"""', write={ - "bigquery": r"'/\\*.*\\*/'", + "bigquery": r"'/\*.*\*/'", "duckdb": r"'/\*.*\*/'", "presto": r"'/\*.*\*/'", - "hive": r"'/\\*.*\\*/'", - "spark": r"'/\\*.*\\*/'", + "hive": r"'/\*.*\*/'", + "spark": r"'/\*.*\*/'", }, ) + with self.assertRaises(RuntimeError): + transpile("'\\'", read="bigquery") + self.validate_all( - r"'\\'", + "'\\\\'", write={ "bigquery": r"'\\'", - "duckdb": r"'\'", - "presto": r"'\'", + "duckdb": r"'\\'", + "presto": r"'\\'", "hive": r"'\\'", }, ) self.validate_all( - R'R"""/\*.*\*/"""', + r'R"""/\*.*\*/"""', write={ - "bigquery": R"'/\\*.*\\*/'", - "duckdb": R"'/\*.*\*/'", - "presto": R"'/\*.*\*/'", - "hive": R"'/\\*.*\\*/'", - "spark": R"'/\\*.*\\*/'", + "bigquery": r"'/\*.*\*/'", + "duckdb": r"'/\*.*\*/'", + "presto": r"'/\*.*\*/'", + "hive": r"'/\*.*\*/'", + "spark": r"'/\*.*\*/'", }, ) self.validate_all( @@ -228,6 +231,12 @@ class TestBigQuery(Validator): }, ) self.validate_all( + "CREATE TABLE db.example_table (x int) PARTITION BY x cluster by x", + write={ + "bigquery": "CREATE TABLE db.example_table (x INT64) PARTITION BY x CLUSTER BY x", + }, + ) + self.validate_all( "SELECT * FROM a WHERE b IN UNNEST([1, 2, 3])", write={ "bigquery": "SELECT * FROM a WHERE b IN UNNEST([1, 2, 3])", @@ -324,6 +333,12 @@ class TestBigQuery(Validator): "bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))", }, ) + self.validate_all( + "SELECT REGEXP_EXTRACT(abc, 'pattern(group)') FROM table", + write={ + "duckdb": "SELECT REGEXP_EXTRACT(abc, 'pattern(group)', 1) FROM table", + }, + ) self.validate_identity("BEGIN A B C D E F") self.validate_identity("BEGIN TRANSACTION") self.validate_identity("COMMIT TRANSACTION") diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 5ae5c6f..48ea6d1 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -6,6 +6,8 @@ class TestDatabricks(Validator): def test_databricks(self): self.validate_identity("CREATE FUNCTION a.b(x INT) RETURNS INT RETURN x + 1") + self.validate_identity("CREATE FUNCTION a AS b") + self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1") def test_datediff(self): self.validate_all( diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 442fbbb..5f048da 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -573,19 +573,33 @@ class TestDialect(Validator): self.validate_all( "DATE_TRUNC('year', x)", read={ + "bigquery": "DATE_TRUNC(x, year)", "starrocks": "DATE_TRUNC('year', x)", + "spark": "TRUNC(x, 'year')", }, write={ + "bigquery": "DATE_TRUNC(x, year)", + "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", + "postgres": "DATE_TRUNC('year', x)", "starrocks": "DATE_TRUNC('year', x)", + "spark": "TRUNC(x, 'year')", }, ) self.validate_all( - "DATE_TRUNC(x, year)", + "TIMESTAMP_TRUNC(x, year)", read={ - "bigquery": "DATE_TRUNC(x, year)", + "bigquery": "TIMESTAMP_TRUNC(x, year)", + "spark": "DATE_TRUNC('year', x)", }, write={ - "bigquery": "DATE_TRUNC(x, year)", + "bigquery": "TIMESTAMP_TRUNC(x, year)", + "spark": "DATE_TRUNC('year', x)", + }, + ) + self.validate_all( + "DATE_TRUNC('millenium', x)", + write={ + "mysql": UnsupportedError, }, ) self.validate_all( diff --git a/tests/dialects/test_drill.py b/tests/dialects/test_drill.py index 9819daa..a196013 100644 --- a/tests/dialects/test_drill.py +++ b/tests/dialects/test_drill.py @@ -34,11 +34,11 @@ class TestDrill(Validator): self.validate_all( "'\\\\a'", read={ - "presto": "'\\a'", + "presto": "'\\\\a'", }, write={ - "duckdb": "'\\a'", - "presto": "'\\a'", + "duckdb": "'\\\\a'", + "presto": "'\\\\a'", "hive": "'\\\\a'", "spark": "'\\\\a'", }, diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index e5cb833..46e75c0 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -317,6 +317,8 @@ class TestDuckDB(Validator): }, ) + self.validate_identity("ATTACH DATABASE ':memory:' AS new_database") + with self.assertRaises(UnsupportedError): transpile( "SELECT a FROM b PIVOT(SUM(x) FOR y IN ('z', 'q'))", @@ -324,6 +326,14 @@ class TestDuckDB(Validator): unsupported_level=ErrorLevel.IMMEDIATE, ) + with self.assertRaises(UnsupportedError): + transpile( + "SELECT REGEXP_EXTRACT(a, 'pattern', 1) from table", + read="bigquery", + write="duckdb", + unsupported_level=ErrorLevel.IMMEDIATE, + ) + def test_array(self): self.validate_identity("ARRAY(SELECT id FROM t)") diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 42d9943..a067764 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -214,11 +214,11 @@ class TestHive(Validator): self.validate_all( "'\\\\a'", read={ - "presto": "'\\a'", + "presto": "'\\\\a'", }, write={ - "duckdb": "'\\a'", - "presto": "'\\a'", + "duckdb": "'\\\\a'", + "presto": "'\\\\a'", "hive": "'\\\\a'", "spark": "'\\\\a'", }, @@ -345,13 +345,13 @@ class TestHive(Validator): "INSERT OVERWRITE TABLE zipcodes PARTITION(state = 0) VALUES (896, 'US', 'TAMPA', 33607)" ) self.validate_identity( - "SELECT a, b, SUM(c) FROM tabl AS t GROUP BY a, b GROUPING SETS ((a, b), a)" + "SELECT a, b, SUM(c) FROM tabl AS t GROUP BY a, b, GROUPING SETS ((a, b), a)" ) self.validate_identity( - "SELECT a, b, SUM(c) FROM tabl AS t GROUP BY a, b GROUPING SETS ((t.a, b), a)" + "SELECT a, b, SUM(c) FROM tabl AS t GROUP BY a, b, GROUPING SETS ((t.a, b), a)" ) self.validate_identity( - "SELECT a, b, SUM(c) FROM tabl AS t GROUP BY a, FOO(b) GROUPING SETS ((a, FOO(b)), a)" + "SELECT a, b, SUM(c) FROM tabl AS t GROUP BY a, FOO(b), GROUPING SETS ((a, FOO(b)), a)" ) self.validate_identity( "SELECT key, value, GROUPING__ID, COUNT(*) FROM T1 GROUP BY key, value WITH CUBE" @@ -648,8 +648,20 @@ class TestHive(Validator): }, ) self.validate_all( - "SELECT a, SUM(c) FROM t GROUP BY a, DATE_FORMAT(b, 'yyyy') GROUPING SETS ((a, DATE_FORMAT(b, 'yyyy')), a)", + "SELECT a, SUM(c) FROM t GROUP BY a, DATE_FORMAT(b, 'yyyy'), GROUPING SETS ((a, DATE_FORMAT(b, 'yyyy')), a)", write={ - "hive": "SELECT a, SUM(c) FROM t GROUP BY a, DATE_FORMAT(CAST(b AS TIMESTAMP), 'yyyy') GROUPING SETS ((a, DATE_FORMAT(CAST(b AS TIMESTAMP), 'yyyy')), a)", + "hive": "SELECT a, SUM(c) FROM t GROUP BY a, DATE_FORMAT(CAST(b AS TIMESTAMP), 'yyyy'), GROUPING SETS ((a, DATE_FORMAT(CAST(b AS TIMESTAMP), 'yyyy')), a)", }, ) + + def test_escapes(self) -> None: + self.validate_identity("'\n'") + self.validate_identity("'\\n'") + self.validate_identity("'\\\n'") + self.validate_identity("'\\\\n'") + self.validate_identity("''") + self.validate_identity("'\\\\'") + self.validate_identity("'\z'") + self.validate_identity("'\\z'") + self.validate_identity("'\\\z'") + self.validate_identity("'\\\\z'") diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 3e3b0d3..192f9fc 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -90,6 +90,14 @@ class TestMySQL(Validator): self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')") self.validate_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')") self.validate_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')") + self.validate_identity( + "SELECT DAY_OF_MONTH('2023-01-01')", "SELECT DAYOFMONTH('2023-01-01')" + ) + self.validate_identity("SELECT DAY_OF_WEEK('2023-01-01')", "SELECT DAYOFWEEK('2023-01-01')") + self.validate_identity("SELECT DAY_OF_YEAR('2023-01-01')", "SELECT DAYOFYEAR('2023-01-01')") + self.validate_identity( + "SELECT WEEK_OF_YEAR('2023-01-01')", "SELECT WEEKOFYEAR('2023-01-01')" + ) def test_escape(self): self.validate_all( @@ -249,26 +257,26 @@ class TestMySQL(Validator): "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'" ) self.validate_identity( - "CREATE TABLE z (a INT DEFAULT NULL, PRIMARY KEY(a)) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'" + "CREATE TABLE z (a INT DEFAULT NULL, PRIMARY KEY (a)) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'" ) self.validate_all( """ CREATE TABLE `t_customer_account` ( - "id" int(11) NOT NULL AUTO_INCREMENT, - "customer_id" int(11) DEFAULT NULL COMMENT '客户id', - "bank" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别', - "account_no" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号', - PRIMARY KEY ("id") + `id` int(11) NOT NULL AUTO_INCREMENT, + `customer_id` int(11) DEFAULT NULL COMMENT '客户id', + `bank` varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别', + `account_no` varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号', + PRIMARY KEY (`id`) ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='客户账户表' """, write={ "mysql": """CREATE TABLE `t_customer_account` ( - 'id' INT(11) NOT NULL AUTO_INCREMENT, - 'customer_id' INT(11) DEFAULT NULL COMMENT '客户id', - 'bank' VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别', - 'account_no' VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号', - PRIMARY KEY('id') + `id` INT(11) NOT NULL AUTO_INCREMENT, + `customer_id` INT(11) DEFAULT NULL COMMENT '客户id', + `bank` VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别', + `account_no` VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号', + PRIMARY KEY (`id`) ) ENGINE=InnoDB AUTO_INCREMENT=1 diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 1fadb84..f85a117 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -2,5 +2,68 @@ from tests.dialects.test_dialect import Validator class TestOracle(Validator): + dialect = "oracle" + def test_oracle(self): self.validate_identity("SELECT * FROM V$SESSION") + + def test_xml_table(self): + self.validate_identity("XMLTABLE('x')") + self.validate_identity("XMLTABLE('x' RETURNING SEQUENCE BY REF)") + self.validate_identity("XMLTABLE('x' PASSING y)") + self.validate_identity("XMLTABLE('x' PASSING y RETURNING SEQUENCE BY REF)") + self.validate_identity( + "XMLTABLE('x' RETURNING SEQUENCE BY REF COLUMNS a VARCHAR2, b FLOAT)" + ) + + self.validate_all( + """SELECT warehouse_name warehouse, + warehouse2."Water", warehouse2."Rail" + FROM warehouses, + XMLTABLE('/Warehouse' + PASSING warehouses.warehouse_spec + COLUMNS + "Water" varchar2(6) PATH 'WaterAccess', + "Rail" varchar2(6) PATH 'RailAccess') + warehouse2""", + write={ + "oracle": """SELECT + warehouse_name AS warehouse, + warehouse2."Water", + warehouse2."Rail" +FROM warehouses, XMLTABLE( + '/Warehouse' + PASSING + warehouses.warehouse_spec + COLUMNS + "Water" VARCHAR2(6) PATH 'WaterAccess', + "Rail" VARCHAR2(6) PATH 'RailAccess' +) warehouse2""", + }, + pretty=True, + ) + + self.validate_all( + """SELECT table_name, column_name, data_default FROM xmltable('ROWSET/ROW' + passing dbms_xmlgen.getxmltype('SELECT table_name, column_name, data_default FROM user_tab_columns') + columns table_name VARCHAR2(128) PATH '*[1]' + , column_name VARCHAR2(128) PATH '*[2]' + , data_default VARCHAR2(2000) PATH '*[3]' + );""", + write={ + "oracle": """SELECT + table_name, + column_name, + data_default +FROM XMLTABLE( + 'ROWSET/ROW' + PASSING + dbms_xmlgen.getxmltype ("SELECT table_name, column_name, data_default FROM user_tab_columns") + COLUMNS + table_name VARCHAR2(128) PATH '*[1]', + column_name VARCHAR2(128) PATH '*[2]', + data_default VARCHAR2(2000) PATH '*[3]' +)""", + }, + pretty=True, + ) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 5664a2a..f0117bc 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -10,6 +10,13 @@ class TestPostgres(Validator): self.validate_identity("CREATE TABLE test (foo HSTORE)") self.validate_identity("CREATE TABLE test (foo JSONB)") self.validate_identity("CREATE TABLE test (foo VARCHAR(64)[])") + + self.validate_all( + "CREATE OR REPLACE FUNCTION function_name (input_a character varying DEFAULT NULL::character varying)", + write={ + "postgres": "CREATE OR REPLACE FUNCTION function_name(input_a VARCHAR DEFAULT CAST(NULL AS VARCHAR))", + }, + ) self.validate_all( "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)", write={ @@ -56,20 +63,7 @@ class TestPostgres(Validator): ) def test_postgres(self): - self.validate_all( - "x ^ y", - write={ - "": "POWER(x, y)", - "postgres": "x ^ y", - }, - ) - self.validate_all( - "x # y", - write={ - "": "x ^ y", - "postgres": "x # y", - }, - ) + self.validate_identity("$x") self.validate_identity("SELECT ARRAY[1, 2, 3]") self.validate_identity("SELECT ARRAY(SELECT 1)") self.validate_identity("SELECT ARRAY_LENGTH(ARRAY[1, 2, 3], 1)") @@ -113,6 +107,20 @@ class TestPostgres(Validator): self.validate_identity("x ~* 'y'") self.validate_all( + "x ^ y", + write={ + "": "POWER(x, y)", + "postgres": "x ^ y", + }, + ) + self.validate_all( + "x # y", + write={ + "": "x ^ y", + "postgres": "x # y", + }, + ) + self.validate_all( "GENERATE_SERIES(a, b, ' 2 days ')", write={ "postgres": "GENERATE_SERIES(a, b, INTERVAL '2' days)", diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 9815dcc..bf22652 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -366,6 +366,12 @@ class TestPresto(Validator): self.validate_identity("APPROX_PERCENTILE(a, b, c, d)") self.validate_all( + "SELECT a FROM t GROUP BY a, ROLLUP(b), ROLLUP(c), ROLLUP(d)", + write={ + "presto": "SELECT a FROM t GROUP BY a, ROLLUP (b, c, d)", + }, + ) + self.validate_all( 'SELECT a."b" FROM "foo"', write={ "duckdb": 'SELECT a."b" FROM "foo"', @@ -507,6 +513,14 @@ class TestPresto(Validator): }, ) + self.validate_all( + "SELECT a, b, c, d, sum(y) FROM z GROUP BY CUBE(a) ROLLUP(a), GROUPING SETS((b, c)), d", + write={ + "presto": "SELECT a, b, c, d, SUM(y) FROM z GROUP BY d, GROUPING SETS ((b, c)), CUBE (a), ROLLUP (a)", + "hive": "SELECT a, b, c, d, SUM(y) FROM z GROUP BY d, GROUPING SETS ((b, c)), CUBE (a), ROLLUP (a)", + }, + ) + def test_encode_decode(self): self.validate_all( "TO_UTF8(x)", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index e20661e..fa4d422 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -47,7 +47,7 @@ class TestRedshift(Validator): self.validate_all( 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\_%\' LIMIT 5', write={ - "redshift": 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\\\_%\' LIMIT 5' + "redshift": 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\_%\' LIMIT 5' }, ) self.validate_all( @@ -72,6 +72,13 @@ class TestRedshift(Validator): "postgres": "COALESCE(a, b, c, d)", }, ) + self.validate_all( + "DATEDIFF(d, a, b)", + write={ + "redshift": "DATEDIFF(d, a, b)", + "presto": "DATE_DIFF(d, a, b)", + }, + ) def test_identity(self): self.validate_identity( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 201cc4e..9e22527 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -1,4 +1,4 @@ -from sqlglot import UnsupportedError +from sqlglot import UnsupportedError, exp, parse_one from tests.dialects.test_dialect import Validator @@ -6,6 +6,7 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): + self.validate_identity("$x") self.validate_identity("SELECT REGEXP_LIKE(a, b, c)") self.validate_identity("PUT file:///dir/tmp.csv @%table") self.validate_identity("CREATE TABLE foo (bar FLOAT AUTOINCREMENT START 0 INCREMENT 1)") @@ -202,10 +203,10 @@ class TestSnowflake(Validator): self.validate_all( r"SELECT $$a ' \ \t \x21 z $ $$", write={ - "snowflake": r"SELECT 'a \' \\ \t \\x21 z $ '", + "snowflake": r"SELECT 'a \' \ \t \x21 z $ '", }, ) - self.validate_identity(r"REGEXP_REPLACE('target', 'pattern', '\n')") + self.validate_identity("REGEXP_REPLACE('target', 'pattern', '\n')") self.validate_all( "SELECT RLIKE(a, b)", write={ @@ -612,6 +613,13 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA }, ) + def test_parse_like_any(self): + like = parse_one("a LIKE ANY fun('foo')", read="snowflake") + ilike = parse_one("a ILIKE ANY fun('foo')", read="snowflake") + + self.assertIsInstance(like, exp.LikeAny) + self.assertIsInstance(ilike, exp.ILikeAny) + def test_match_recognize(self): for row in ( "ONE ROW PER MATCH", diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index c4f4a6e..19a88f3 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -5,6 +5,12 @@ class TestSQLite(Validator): dialect = "sqlite" def test_ddl(self): + self.validate_identity("INSERT OR ABORT INTO foo (x, y) VALUES (1, 2)") + self.validate_identity("INSERT OR FAIL INTO foo (x, y) VALUES (1, 2)") + self.validate_identity("INSERT OR IGNORE INTO foo (x, y) VALUES (1, 2)") + self.validate_identity("INSERT OR REPLACE INTO foo (x, y) VALUES (1, 2)") + self.validate_identity("INSERT OR ROLLBACK INTO foo (x, y) VALUES (1, 2)") + self.validate_all( "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)", write={"sqlite": "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)"}, diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 9e82961..ab87eef 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -24,3 +24,41 @@ class TestTeradata(Validator): def test_create(self): self.validate_identity("CREATE TABLE x (y INT) PRIMARY INDEX (y) PARTITION BY y INDEX (y)") + + self.validate_all( + "REPLACE VIEW a AS (SELECT b FROM c)", + write={"teradata": "CREATE OR REPLACE VIEW a AS (SELECT b FROM c)"}, + ) + + self.validate_all( + "SEL a FROM b", + write={"teradata": "SELECT a FROM b"}, + ) + + def test_insert(self): + self.validate_all( + "INS INTO x SELECT * FROM y", write={"teradata": "INSERT INTO x SELECT * FROM y"} + ) + + def test_mod(self): + self.validate_all("a MOD b", write={"teradata": "a MOD b", "mysql": "a % b"}) + + def test_abbrev(self): + self.validate_all("a LT b", write={"teradata": "a < b"}) + self.validate_all("a LE b", write={"teradata": "a <= b"}) + self.validate_all("a GT b", write={"teradata": "a > b"}) + self.validate_all("a GE b", write={"teradata": "a >= b"}) + self.validate_all("a ^= b", write={"teradata": "a <> b"}) + self.validate_all("a NE b", write={"teradata": "a <> b"}) + self.validate_all("a NOT= b", write={"teradata": "a <> b"}) + + def test_datatype(self): + self.validate_all( + "CREATE TABLE z (a ST_GEOMETRY(1))", + write={ + "teradata": "CREATE TABLE z (a ST_GEOMETRY(1))", + "redshift": "CREATE TABLE z (a GEOMETRY(1))", + }, + ) + + self.validate_identity("CREATE TABLE z (a SYSUDTLIB.INT)") diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index b3f546b..7b9ae6d 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -16,6 +16,18 @@ SUM(CASE WHEN x > 1 THEN 1 ELSE 0 END) / y 'x' '\x' "x" +'\z' +'\\z' +'\\\z' +'\\\\z' +'\\\\\z' +'\\\\\\z' +'\n' +'\\n' +'\\\n' +'\\\\n' +'\\\\\n' +'\\\\\\n' "" """x""" N'abc' @@ -502,7 +514,7 @@ CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id') CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1) CREATE TABLE z (a INT(11) NOT NULL DEFAULT -1) CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT) -CREATE TABLE z (a INT, PRIMARY KEY(a)) +CREATE TABLE z (a INT, PRIMARY KEY (a)) CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1 CREATE TABLE z WITH (FORMAT='ORC', x='2') AS SELECT 1 CREATE TABLE z WITH (TABLE_FORMAT='iceberg', FORMAT='parquet') AS SELECT 1 @@ -530,9 +542,13 @@ CREATE TABLE asd AS SELECT asd FROM asd WITH DATA CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY) CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY) CREATE TABLE IF NOT EXISTS customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (INCREMENT BY 1)) -CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1)) +CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1 MINVALUE -1 MAXVALUE 1 NO CYCLE)) CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10)) +CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (CYCLE)) CREATE TABLE foo (baz_id INT REFERENCES baz(id) DEFERRABLE) +CREATE TABLE foo (baz CHAR(4) CHARACTER SET LATIN UPPERCASE NOT CASESPECIFIC) +CREATE TABLE foo (baz DATE FORMAT 'YYYY/MM/DD' TITLE 'title') +CREATE TABLE t (title TEXT) CREATE TABLE a, FALLBACK, LOG, JOURNAL, CHECKSUM=DEFAULT, DEFAULT MERGEBLOCKRATIO, BLOCKCOMPRESSION=MANUAL (a INT) CREATE TABLE a, NO FALLBACK PROTECTION, NO LOG, NO JOURNAL, CHECKSUM=ON, NO MERGEBLOCKRATIO, BLOCKCOMPRESSION=ALWAYS (a INT) CREATE TABLE a, WITH JOURNAL TABLE=x.y.z, CHECKSUM=OFF, MERGEBLOCKRATIO=1, DATABLOCKSIZE=10 KBYTES (a INT) @@ -556,6 +572,7 @@ CREATE TEMPORARY VIEW x AS WITH y AS (SELECT 1) SELECT * FROM y CREATE MATERIALIZED VIEW x.y.z AS SELECT a FROM b CREATE VIEW z (a, b) CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c') +CREATE VIEW z AS LOCKING ROW FOR ACCESS SELECT a FROM b CREATE TEMPORARY FUNCTION f CREATE TEMPORARY FUNCTION f AS 'g' CREATE FUNCTION f @@ -731,3 +748,5 @@ SELECT id FROM b.a AS a QUALIFY ROW_NUMBER() OVER (PARTITION BY br ORDER BY sadf SELECT LEFT.FOO FROM BLA AS LEFT SELECT RIGHT.FOO FROM BLA AS RIGHT SELECT LEFT FROM LEFT LEFT JOIN RIGHT RIGHT JOIN LEFT +SELECT * FROM x WHERE name ILIKE ANY XXX('a', 'b') +SELECT * FROM x WHERE name LIKE ANY XXX('a', 'b') diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index 9c14ec1..6ccf24e 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -1,15 +1,10 @@ # title: lateral # execute: false SELECT a, m FROM z LATERAL VIEW EXPLODE([1, 2]) q AS m; -WITH "z_2" AS ( - SELECT - "z"."a" AS "a" - FROM "z" AS "z" -) SELECT "z"."a" AS "a", "q"."m" AS "m" -FROM "z_2" AS "z" +FROM "z" AS "z" LATERAL VIEW EXPLODE(ARRAY(1, 2)) q AS "m"; diff --git a/tests/fixtures/optimizer/pushdown_projections.sql b/tests/fixtures/optimizer/pushdown_projections.sql index 03ecf16..107e92f 100644 --- a/tests/fixtures/optimizer/pushdown_projections.sql +++ b/tests/fixtures/optimizer/pushdown_projections.sql @@ -57,3 +57,30 @@ SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y); SELECT i.a FROM x AS i LEFT JOIN (SELECT a, b FROM (SELECT a, b FROM x)) AS j ON i.a = j.a; SELECT i.a AS a FROM x AS i LEFT JOIN (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) AS j ON i.a = j.a; + +-------------------------------------- +-- Unknown Star Expansion +-------------------------------------- +SELECT a FROM (SELECT * FROM zz) WHERE b = 1; +SELECT _q_0.a AS a FROM (SELECT zz.a AS a, zz.b AS b FROM zz AS zz) AS _q_0 WHERE _q_0.b = 1; + +SELECT a FROM (SELECT * FROM aa UNION ALL SELECT * FROM bb UNION ALL SELECT * from cc); +SELECT _q_0.a AS a FROM (SELECT aa.a AS a FROM aa AS aa UNION ALL SELECT bb.a AS a FROM bb AS bb UNION ALL SELECT cc.a AS a FROM cc AS cc) AS _q_0; + +SELECT a FROM (SELECT a FROM aa UNION ALL SELECT * FROM bb UNION ALL SELECT * from cc); +SELECT _q_0.a AS a FROM (SELECT aa.a AS a FROM aa AS aa UNION ALL SELECT bb.a AS a FROM bb AS bb UNION ALL SELECT cc.a AS a FROM cc AS cc) AS _q_0; + +SELECT a FROM (SELECT * FROM aa UNION ALL SELECT * FROM bb UNION ALL SELECT * from cc); +SELECT _q_0.a AS a FROM (SELECT aa.a AS a FROM aa AS aa UNION ALL SELECT bb.a AS a FROM bb AS bb UNION ALL SELECT cc.a AS a FROM cc AS cc) AS _q_0; + +SELECT a FROM (SELECT * FROM aa CROSS JOIN bb); +SELECT _q_0.a AS a FROM (SELECT a AS a FROM aa AS aa CROSS JOIN bb AS bb) AS _q_0; + +SELECT a FROM (SELECT aa.* FROM aa); +SELECT _q_0.a AS a FROM (SELECT aa.a AS a FROM aa AS aa) AS _q_0; + +SELECT a FROM (SELECT * FROM (SELECT * FROM aa)); +SELECT _q_1.a AS a FROM (SELECT _q_0.a AS a FROM (SELECT aa.a AS a FROM aa AS aa) AS _q_0) AS _q_1; + +with cte1 as (SELECT cola, colb FROM tb UNION ALL SELECT colc, cold FROM tb2) SELECT cola FROM cte1; +WITH cte1 AS (SELECT tb.cola AS cola FROM tb AS tb UNION ALL SELECT tb2.colc AS colc FROM tb2 AS tb2) SELECT cte1.cola AS cola FROM cte1; diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 141f028..46c576a 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -215,6 +215,9 @@ SELECT _q_0.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x) AS _q_0; SELECT * FROM (SELECT a FROM x); SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0; +SELECT * FROM x GROUP BY 1, 2; +SELECT x.a AS a, x.b AS b FROM x AS x GROUP BY x.a, x.b; + -------------------------------------- -- CTEs -------------------------------------- @@ -310,6 +313,15 @@ SELECT t.aa AS aa FROM x AS x, UNNEST(x.a) AS t(aa); SELECT aa FROM x, UNNEST(a) AS aa; SELECT aa AS aa FROM x AS x, UNNEST(x.a) AS aa; +# execute: false +# dialect: presto +SELECT x.a, i.b FROM x CROSS JOIN UNNEST(SPLIT(b, ',')) AS i(b); +SELECT x.a AS a, i.b AS b FROM x AS x CROSS JOIN UNNEST(SPLIT(x.b, ',')) AS i(b); + +# execute: false +SELECT c FROM (SELECT 1 a) AS x LATERAL VIEW EXPLODE(a) AS c; +SELECT _q_0.c AS c FROM (SELECT 1 AS a) AS x LATERAL VIEW EXPLODE(x.a) _q_0 AS c; + -------------------------------------- -- Window functions -------------------------------------- diff --git a/tests/fixtures/optimizer/qualify_columns__invalid.sql b/tests/fixtures/optimizer/qualify_columns__invalid.sql index 2a3ccfb..f0f9f87 100644 --- a/tests/fixtures/optimizer/qualify_columns__invalid.sql +++ b/tests/fixtures/optimizer/qualify_columns__invalid.sql @@ -1,4 +1,3 @@ -SELECT * FROM zz; SELECT z.a FROM x; SELECT z.* FROM x; SELECT x FROM x; @@ -11,3 +10,4 @@ SELECT b FROM x AS a CROSS JOIN y AS b CROSS JOIN y AS c; SELECT x.a FROM x JOIN y USING (a); SELECT a, SUM(b) FROM x GROUP BY 3; SELECT p FROM (SELECT x from xx) y CROSS JOIN yy CROSS JOIN zz +select a from (select * from x cross join y); diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql index c67ba5d..a06af88 100644 --- a/tests/fixtures/pretty.sql +++ b/tests/fixtures/pretty.sql @@ -40,7 +40,7 @@ WITH cte1 AS ( FROM (SELECT 1) AS x, y, (SELECT 2) z UNION ALL SELECT MAX(COALESCE(x AND y, a and b and c, d and e)), FOO(CASE WHEN a and b THEN c and d ELSE 3 END) - GROUP BY x, GROUPING SETS (a, (b, c)) CUBE(y, z) + GROUP BY x, GROUPING SETS (a, (b, c)), CUBE(y, z) ) x ) SELECT a, b c FROM ( @@ -95,7 +95,7 @@ WITH cte1 AS ( MAX(COALESCE(x AND y, a AND b AND c, d AND e)), FOO(CASE WHEN a AND b THEN c AND d ELSE 3 END) GROUP BY - x + x, GROUPING SETS ( a, (b, c) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 7acc0fa..8b74fe1 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -406,6 +406,8 @@ class TestExpressions(unittest.TestCase): ) def test_functions(self): + self.assertIsInstance(parse_one("x LIKE ANY (y)"), exp.Like) + self.assertIsInstance(parse_one("x ILIKE ANY (y)"), exp.ILike) self.assertIsInstance(parse_one("ABS(a)"), exp.Abs) self.assertIsInstance(parse_one("APPROX_DISTINCT(a)"), exp.ApproxDistinct) self.assertIsInstance(parse_one("ARRAY(a)"), exp.Array) @@ -473,23 +475,24 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("GENERATE_SERIES(a, b, c)"), exp.GenerateSeries) def test_column(self): - dot = parse_one("a.b.c") - column = dot.this - self.assertEqual(column.table, "a") - self.assertEqual(column.name, "b") - self.assertEqual(dot.text("expression"), "c") + column = parse_one("a.b.c.d") + self.assertEqual(column.catalog, "a") + self.assertEqual(column.db, "b") + self.assertEqual(column.table, "c") + self.assertEqual(column.name, "d") column = parse_one("a") self.assertEqual(column.name, "a") self.assertEqual(column.table, "") - fields = parse_one("a.b.c.d") + fields = parse_one("a.b.c.d.e") self.assertIsInstance(fields, exp.Dot) - self.assertEqual(fields.text("expression"), "d") - self.assertEqual(fields.this.text("expression"), "c") + self.assertEqual(fields.text("expression"), "e") column = fields.find(exp.Column) - self.assertEqual(column.name, "b") - self.assertEqual(column.table, "a") + self.assertEqual(column.name, "d") + self.assertEqual(column.table, "c") + self.assertEqual(column.db, "b") + self.assertEqual(column.catalog, "a") column = parse_one("a[0].b") self.assertIsInstance(column, exp.Dot) @@ -505,8 +508,8 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("*"), exp.Star) def test_text(self): - column = parse_one("a.b.c") - self.assertEqual(column.text("expression"), "c") + column = parse_one("a.b.c.d.e") + self.assertEqual(column.text("expression"), "e") self.assertEqual(column.text("y"), "") self.assertEqual(parse_one("select * from x.y").find(exp.Table).text("db"), "x") self.assertEqual(parse_one("select *").name, "") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index b6993ba..8ddd95f 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -175,7 +175,7 @@ class TestOptimizer(unittest.TestCase): def pushdown_projections(expression, **kwargs): expression = optimizer.qualify_tables.qualify_tables(expression) expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) - expression = optimizer.pushdown_projections.pushdown_projections(expression) + expression = optimizer.pushdown_projections.pushdown_projections(expression, **kwargs) return expression self.check_file("pushdown_projections", pushdown_projections, schema=self.schema) @@ -519,6 +519,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') concat_expr.right.expressions[0].type.this, exp.DataType.Type.VARCHAR ) # x.cola (arg) + annotate_types(parse_one("select x from y lateral view explode(y) as x")).expressions[0] + def test_null_annotation(self): expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 2c3b874..d30c445 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -38,7 +38,8 @@ class TestTokens(unittest.TestCase): tokens, [ (TokenType.SELECT, "SELECT"), - (TokenType.BLOCK_START, "{{"), + (TokenType.L_BRACE, "{"), + (TokenType.L_BRACE, "{"), (TokenType.VAR, "x"), (TokenType.R_BRACE, "}"), (TokenType.R_BRACE, "}"), |