diff options
Diffstat (limited to '')
-rw-r--r-- | tests/dialects/test_bigquery.py | 11 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 7 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 31 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 26 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 11 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 13 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 82 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 9 | ||||
-rw-r--r-- | tests/dialects/test_teradata.py | 23 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 7 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 29 | ||||
-rw-r--r-- | tests/fixtures/pretty.sql | 20 | ||||
-rw-r--r-- | tests/test_expressions.py | 10 | ||||
-rw-r--r-- | tests/test_optimizer.py | 4 | ||||
-rw-r--r-- | tests/test_parser.py | 6 | ||||
-rw-r--r-- | tests/test_transpile.py | 8 |
17 files changed, 281 insertions, 24 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index c61a2f3..e5b1c94 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -6,6 +6,8 @@ class TestBigQuery(Validator): dialect = "bigquery" def test_bigquery(self): + self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])") + self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))") self.validate_all( "REGEXP_CONTAINS('foo', '.*')", read={"bigquery": "REGEXP_CONTAINS('foo', '.*')"}, @@ -42,6 +44,15 @@ class TestBigQuery(Validator): }, ) self.validate_all( + r"'\\'", + write={ + "bigquery": r"'\\'", + "duckdb": r"'\'", + "presto": r"'\'", + "hive": r"'\\'", + }, + ) + self.validate_all( R'R"""/\*.*\*/"""', write={ "bigquery": R"'/\\*.*\\*/'", diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 109e9f3..2827dd4 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -17,6 +17,7 @@ class TestClickhouse(Validator): self.validate_identity("SELECT quantile(0.5)(a)") self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t") self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)") + self.validate_identity("position(a, b)") self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", @@ -47,3 +48,9 @@ class TestClickhouse(Validator): "clickhouse": "SELECT quantileIf(0.5)(a, TRUE)", }, ) + + def test_cte(self): + self.validate_identity("WITH 'x' AS foo SELECT foo") + self.validate_identity("WITH SUM(bytes) AS foo SELECT foo FROM system.parts") + self.validate_identity("WITH (SELECT foo) AS bar SELECT bar + 5") + self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 284a30d..b2f4676 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -14,7 +14,7 @@ class Validator(unittest.TestCase): self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect)) return expression - def validate_all(self, sql, read=None, write=None, pretty=False): + def validate_all(self, sql, read=None, write=None, pretty=False, identify=False): """ Validate that: 1. Everything in `read` transpiles to `sql` @@ -32,7 +32,10 @@ class Validator(unittest.TestCase): with self.subTest(f"{read_dialect} -> {sql}"): self.assertEqual( parse_one(read_sql, read_dialect).sql( - self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty + self.dialect, + unsupported_level=ErrorLevel.IGNORE, + pretty=pretty, + identify=identify, ), sql, ) @@ -48,6 +51,7 @@ class Validator(unittest.TestCase): write_dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty, + identify=identify, ), write_sql, ) @@ -76,7 +80,7 @@ class TestDialect(Validator): "oracle": "CAST(a AS CLOB)", "postgres": "CAST(a AS TEXT)", "presto": "CAST(a AS VARCHAR)", - "redshift": "CAST(a AS TEXT)", + "redshift": "CAST(a AS VARCHAR(MAX))", "snowflake": "CAST(a AS TEXT)", "spark": "CAST(a AS STRING)", "starrocks": "CAST(a AS STRING)", @@ -155,7 +159,7 @@ class TestDialect(Validator): "oracle": "CAST(a AS CLOB)", "postgres": "CAST(a AS TEXT)", "presto": "CAST(a AS VARCHAR)", - "redshift": "CAST(a AS TEXT)", + "redshift": "CAST(a AS VARCHAR(MAX))", "snowflake": "CAST(a AS TEXT)", "spark": "CAST(a AS STRING)", "starrocks": "CAST(a AS STRING)", @@ -344,6 +348,7 @@ class TestDialect(Validator): "duckdb": "CAST('2020-01-01' AS TIMESTAMP)", "hive": "CAST('2020-01-01' AS TIMESTAMP)", "presto": "CAST('2020-01-01' AS TIMESTAMP)", + "sqlite": "'2020-01-01'", }, ) self.validate_all( @@ -373,7 +378,7 @@ class TestDialect(Validator): "duckdb": "CAST(x AS TEXT)", "hive": "CAST(x AS STRING)", "presto": "CAST(x AS VARCHAR)", - "redshift": "CAST(x AS TEXT)", + "redshift": "CAST(x AS VARCHAR(MAX))", }, ) self.validate_all( @@ -488,7 +493,9 @@ class TestDialect(Validator): "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", "postgres": "x + INTERVAL '1' 'day'", "presto": "DATE_ADD('day', 1, x)", + "snowflake": "DATEADD(x, 1, 'day')", "spark": "DATE_ADD(x, 1)", + "sqlite": "DATE(x, '1 day')", "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", "tsql": "DATEADD(day, 1, x)", }, @@ -594,6 +601,7 @@ class TestDialect(Validator): "hive": "TO_DATE(x)", "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", "spark": "TO_DATE(x)", + "sqlite": "x", }, ) self.validate_all( @@ -955,7 +963,7 @@ class TestDialect(Validator): }, ) self.validate_all( - "STR_POSITION('a', x)", + "STR_POSITION(x, 'a')", write={ "drill": "STRPOS(x, 'a')", "duckdb": "STRPOS(x, 'a')", @@ -971,7 +979,7 @@ class TestDialect(Validator): "POSITION('a', x, 3)", write={ "drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", - "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", + "presto": "STRPOS(x, 'a', 3)", "spark": "LOCATE('a', x, 3)", "clickhouse": "position(x, 'a', 3)", "snowflake": "POSITION('a', x, 3)", @@ -982,9 +990,10 @@ class TestDialect(Validator): "CONCAT_WS('-', 'a', 'b')", write={ "duckdb": "CONCAT_WS('-', 'a', 'b')", - "presto": "ARRAY_JOIN(ARRAY['a', 'b'], '-')", + "presto": "CONCAT_WS('-', 'a', 'b')", "hive": "CONCAT_WS('-', 'a', 'b')", "spark": "CONCAT_WS('-', 'a', 'b')", + "trino": "CONCAT_WS('-', 'a', 'b')", }, ) @@ -992,9 +1001,10 @@ class TestDialect(Validator): "CONCAT_WS('-', x)", write={ "duckdb": "CONCAT_WS('-', x)", - "presto": "ARRAY_JOIN(x, '-')", "hive": "CONCAT_WS('-', x)", + "presto": "CONCAT_WS('-', x)", "spark": "CONCAT_WS('-', x)", + "trino": "CONCAT_WS('-', x)", }, ) self.validate_all( @@ -1118,6 +1128,7 @@ class TestDialect(Validator): self.validate_all( "SELECT x FROM y OFFSET 10 FETCH FIRST 3 ROWS ONLY", write={ + "sqlite": "SELECT x FROM y LIMIT 3 OFFSET 10", "oracle": "SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY", }, ) @@ -1197,7 +1208,7 @@ class TestDialect(Validator): "oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))", "postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))", "sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))", - "redshift": "CREATE TABLE t (b1 VARBYTE, b2 VARBYTE(1024), c1 TEXT, c2 TEXT(1024))", + "redshift": "CREATE TABLE t (b1 VARBYTE, b2 VARBYTE(1024), c1 VARCHAR(MAX), c2 VARCHAR(1024))", }, ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index bbf00b1..d485593 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -357,6 +357,30 @@ class TestHive(Validator): }, ) self.validate_all( + "SELECT 1a_1a FROM test_a", + write={ + "spark": "SELECT 1a_1a FROM test_a", + }, + ) + self.validate_all( + "SELECT 1a AS 1a_1a FROM test_a", + write={ + "spark": "SELECT 1a AS 1a_1a FROM test_a", + }, + ) + self.validate_all( + "CREATE TABLE test_table (1a STRING)", + write={ + "spark": "CREATE TABLE test_table (1a STRING)", + }, + ) + self.validate_all( + "CREATE TABLE test_table2 (1a_1a STRING)", + write={ + "spark": "CREATE TABLE test_table2 (1a_1a STRING)", + }, + ) + self.validate_all( "PERCENTILE(x, 0.5)", write={ "duckdb": "QUANTILE(x, 0.5)", @@ -420,7 +444,7 @@ class TestHive(Validator): "LOCATE('a', x, 3)", write={ "duckdb": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", - "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", + "presto": "STRPOS(x, 'a', 3)", "hive": "LOCATE('a', x, 3)", "spark": "LOCATE('a', x, 3)", }, diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 7cd686d..dfd2f8e 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -65,6 +65,17 @@ class TestMySQL(Validator): self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE") self.validate_identity("SELECT SCHEMA()") + def test_types(self): + self.validate_all( + "CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT)", + read={ + "mysql": "CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT)", + }, + write={ + "spark": "CAST(x AS TEXT) + CAST(y AS TEXT)", + }, + ) + def test_canonical_functions(self): self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)") self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')") diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 583d349..2351e3b 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -46,14 +46,6 @@ class TestPostgres(Validator): " CONSTRAINT valid_discount CHECK (price > discounted_price))" }, ) - self.validate_all( - "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)", - write={"postgres": "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)"}, - ) - self.validate_all( - "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)", - write={"postgres": "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)"}, - ) with self.assertRaises(ParseError): transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres") diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index ee535e9..195e382 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -152,6 +152,10 @@ class TestPresto(Validator): "spark": "FROM_UNIXTIME(x)", }, ) + self.validate_identity("FROM_UNIXTIME(a, b)") + self.validate_identity("FROM_UNIXTIME(a, b, c)") + self.validate_identity("TRIM(a, b)") + self.validate_identity("VAR_POP(a)") self.validate_all( "TO_UNIXTIME(x)", write={ @@ -302,6 +306,7 @@ class TestPresto(Validator): ) def test_presto(self): + self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)") self.validate_all( 'SELECT a."b" FROM "foo"', write={ @@ -443,8 +448,10 @@ class TestPresto(Validator): "spark": UnsupportedError, }, ) + self.validate_identity("SELECT * FROM (VALUES (1))") self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE") self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ") + self.validate_identity("APPROX_PERCENTILE(a, b, c, d)") def test_encode_decode(self): self.validate_all( @@ -460,6 +467,12 @@ class TestPresto(Validator): }, ) self.validate_all( + "FROM_UTF8(x, y)", + write={ + "presto": "FROM_UTF8(x, y)", + }, + ) + self.validate_all( "ENCODE(x, 'utf-8')", write={ "presto": "TO_UTF8(x)", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index f650c98..e20661e 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -89,7 +89,9 @@ class TestRedshift(Validator): self.validate_identity( "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'" ) - self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL") + self.validate_identity( + "CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL" + ) self.validate_identity( "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO" ) @@ -102,3 +104,81 @@ class TestRedshift(Validator): self.validate_identity( "CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)" ) + + def test_values(self): + self.validate_all( + "SELECT a, b FROM (VALUES (1, 2)) AS t (a, b)", + write={ + "redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b) AS t", + }, + ) + self.validate_all( + "SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)", + write={ + "redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t", + }, + ) + self.validate_all( + "SELECT a, b FROM (VALUES (1, 2), (3, 4), (5, 6), (7, 8)) AS t (a, b)", + write={ + "redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4 UNION ALL SELECT 5, 6 UNION ALL SELECT 7, 8) AS t", + }, + ) + self.validate_all( + "INSERT INTO t(a) VALUES (1), (2), (3)", + write={ + "redshift": "INSERT INTO t (a) VALUES (1), (2), (3)", + }, + ) + self.validate_all( + "INSERT INTO t(a, b) SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)", + write={ + "redshift": "INSERT INTO t (a, b) SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t", + }, + ) + self.validate_all( + "INSERT INTO t(a, b) VALUES (1, 2), (3, 4)", + write={ + "redshift": "INSERT INTO t (a, b) VALUES (1, 2), (3, 4)", + }, + ) + + def test_create_table_like(self): + self.validate_all( + "CREATE TABLE t1 LIKE t2", + write={ + "redshift": "CREATE TABLE t1 (LIKE t2)", + }, + ) + self.validate_all( + "CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL", + write={ + "redshift": "CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL", + }, + ) + + def test_rename_table(self): + self.validate_all( + "ALTER TABLE db.t1 RENAME TO db.t2", + write={ + "spark": "ALTER TABLE db.t1 RENAME TO db.t2", + "redshift": "ALTER TABLE db.t1 RENAME TO t2", + }, + ) + + def test_varchar_max(self): + self.validate_all( + "CREATE TABLE TEST (cola VARCHAR(MAX))", + write={ + "redshift": 'CREATE TABLE "TEST" ("cola" VARCHAR(MAX))', + }, + identify=True, + ) + + def test_no_schema_binding(self): + self.validate_all( + "CREATE OR REPLACE VIEW v1 AS SELECT cola, colb FROM t1 WITH NO SCHEMA BINDING", + write={ + "redshift": "CREATE OR REPLACE VIEW v1 AS SELECT cola, colb FROM t1 WITH NO SCHEMA BINDING", + }, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index f287a89..fad858c 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -307,5 +307,12 @@ TBLPROPERTIES ( def test_iif(self): self.validate_all( - "SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"} + "SELECT IIF(cond, 'True', 'False')", + write={"spark": "SELECT IF(cond, 'True', 'False')"}, + ) + + def test_bool_or(self): + self.validate_all( + "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a", + write={"duckdb": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"}, ) diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py new file mode 100644 index 0000000..e56de25 --- /dev/null +++ b/tests/dialects/test_teradata.py @@ -0,0 +1,23 @@ +from tests.dialects.test_dialect import Validator + + +class TestTeradata(Validator): + dialect = "teradata" + + def test_translate(self): + self.validate_all( + "TRANSLATE(x USING LATIN_TO_UNICODE)", + write={ + "teradata": "CAST(x AS CHAR CHARACTER SET UNICODE)", + }, + ) + self.validate_identity("CAST(x AS CHAR CHARACTER SET UNICODE)") + + def test_update(self): + self.validate_all( + "UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1", + write={ + "teradata": "UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1", + "mysql": "UPDATE A SET col2 = '' FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1", + }, + ) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index b74c05f..d2972ca 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -5,6 +5,13 @@ class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): + self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'") + self.validate_identity("PRINT @TestVariable") + self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar") + self.validate_identity("INSERT INTO @TestTable VALUES (1, 'Value1', 12, 20)") + self.validate_identity( + "SELECT x FROM @MyTableVar AS m JOIN Employee ON m.EmployeeID = Employee.EmployeeID" + ) self.validate_identity('SELECT "x"."y" FROM foo') self.validate_identity("SELECT * FROM #foo") self.validate_identity("SELECT * FROM ##foo") diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index beb5703..4e21d2b 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -59,6 +59,8 @@ map.x SELECT call.x a.b.INT(1.234) INT(x / 100) +time * 100 +int * 100 x IN (-1, 1) x IN ('a', 'a''a') x IN ((1)) @@ -69,6 +71,11 @@ x IS TRUE x IS FALSE x IS TRUE IS TRUE x LIKE y IS TRUE +MAP() +GREATEST(x) +LEAST(y) +MAX(a, b) +MIN(a, b) time zone ARRAY<TEXT> @@ -133,6 +140,7 @@ x AT TIME ZONE 'UTC' CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo' SET x = 1 SET -v +SET x = ';' COMMIT USE db NOT 1 @@ -170,6 +178,7 @@ SELECT COUNT(DISTINCT a, b) SELECT COUNT(DISTINCT a, b + 1) SELECT SUM(DISTINCT x) SELECT SUM(x IGNORE NULLS) AS x +SELECT TRUNCATE(a, b) SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x SELECT ARRAY_AGG(STRUCT(x, x AS y) ORDER BY z DESC) AS x SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x @@ -622,7 +631,7 @@ SELECT 1 /* c1 */ + 2 /* c2 */ + 3 /* c3 */ SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */ SELECT x FROM a.b.c /* x */, e.f.g /* x */ SELECT FOO(x /* c */) /* FOO */, b /* b */ -SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */ +SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM (VALUES (1 /* c4 */, "test" /* c5 */)) /* c6 */ SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b' SELECT x AS INTO FROM bla SELECT * INTO newevent FROM event @@ -643,3 +652,21 @@ ALTER TABLE integers ALTER COLUMN i DROP DEFAULT ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT SELECT div.a FROM test_table AS div +WITH view AS (SELECT 1 AS x) SELECT * FROM view +CREATE TABLE asd AS SELECT asd FROM asd WITH NO DATA +CREATE TABLE asd AS SELECT asd FROM asd WITH DATA +ARRAY<STRUCT<INT, DOUBLE, ARRAY<INT>>> +ARRAY<INT>[1, 2, 3] +ARRAY<INT>[] +STRUCT<x VARCHAR(10)> +STRUCT<x VARCHAR(10)>("bla") +STRUCT<VARCHAR(10)>("bla") +STRUCT<INT>(5) +STRUCT<DATE>("2011-05-05") +STRUCT<x INT, y TEXT>(1, t.str_col) +SELECT CAST(NULL AS ARRAY<INT>) IS NULL AS array_is_null +CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY) +CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY) +CREATE TABLE IF NOT EXISTS customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (INCREMENT BY 1)) +CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1)) +CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10)) diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql index 067fe77..64806eb 100644 --- a/tests/fixtures/pretty.sql +++ b/tests/fixtures/pretty.sql @@ -322,3 +322,23 @@ SELECT * /* multi line comment */; +WITH table_data AS ( + SELECT 'bob' AS name, ARRAY['banana', 'apple', 'orange'] AS fruit_basket +) +SELECT + name, + fruit, + basket_index +FROM table_data +CROSS JOIN UNNEST(fruit_basket) AS fruit WITH OFFSET basket_index; +WITH table_data AS ( + SELECT + 'bob' AS name, + ARRAY('banana', 'apple', 'orange') AS fruit_basket +) +SELECT + name, + fruit, + basket_index +FROM table_data +CROSS JOIN UNNEST(fruit_basket) AS fruit WITH OFFSET AS basket_index; diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 906e08c..9e5f988 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -624,6 +624,10 @@ FROM foo""", self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog")) with self.assertRaises(ValueError): exp.to_table(1) + empty_string = exp.to_table("") + self.assertEqual(empty_string.name, "") + self.assertIsNone(table_only.args.get("db")) + self.assertIsNone(table_only.args.get("catalog")) def test_to_column(self): column_only = exp.to_column("column_name") @@ -715,3 +719,9 @@ FROM foo""", self.assertEqual(exp.DataType.build("OBJECT").sql(), "OBJECT") self.assertEqual(exp.DataType.build("NULL").sql(), "NULL") self.assertEqual(exp.DataType.build("UNKNOWN").sql(), "UNKNOWN") + + def test_rename_table(self): + self.assertEqual( + exp.rename_table("t1", "t2").sql(), + "ALTER TABLE t1 RENAME TO t2", + ) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 887f427..af21679 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -6,7 +6,7 @@ from pandas.testing import assert_frame_equal import sqlglot from sqlglot import exp, optimizer, parse_one -from sqlglot.errors import OptimizeError +from sqlglot.errors import OptimizeError, SchemaError from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope from sqlglot.schema import MappingSchema @@ -161,7 +161,7 @@ class TestOptimizer(unittest.TestCase): def test_qualify_columns__invalid(self): for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"): with self.subTest(sql): - with self.assertRaises(OptimizeError): + with self.assertRaises((OptimizeError, SchemaError)): optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema) def test_lower_identities(self): diff --git a/tests/test_parser.py b/tests/test_parser.py index 03b801b..dbde437 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -325,3 +325,9 @@ class TestParser(unittest.TestCase): "Expected table name", logger, ) + + def test_rename_table(self): + self.assertEqual( + parse_one("ALTER TABLE foo RENAME TO bar").sql(), + "ALTER TABLE foo RENAME TO bar", + ) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 3a7fea4..3e094f5 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -272,6 +272,11 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", "WITH A(filter) AS (VALUES 1, 2, 3) SELECT * FROM A WHERE filter >= 2", "WITH A(filter) AS (VALUES (1), (2), (3)) SELECT * FROM A WHERE filter >= 2", ) + self.validate( + "SELECT BOOL_OR(a > 10) FROM (VALUES 1, 2, 15) AS T(a)", + "SELECT BOOL_OR(a > 10) FROM (VALUES (1), (2), (15)) AS T(a)", + write="presto", + ) def test_alter(self): self.validate( @@ -447,6 +452,9 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", self.assertEqual(generated, pretty) self.assertEqual(parse_one(sql), parse_one(pretty)) + def test_pretty_line_breaks(self): + self.assertEqual(transpile("SELECT '1\n2'", pretty=True)[0], "SELECT\n '1\n2'") + @mock.patch("sqlglot.parser.logger") def test_error_level(self, logger): invalid = "x + 1. (" |