diff options
Diffstat (limited to '')
-rw-r--r-- | tests/dialects/test_bigquery.py | 21 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 17 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 10 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 36 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 27 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 73 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 14 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 18 |
9 files changed, 196 insertions, 28 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 1337c3d..c929e59 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -236,3 +236,24 @@ class TestBigQuery(Validator): "snowflake": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a NULLS FIRST LIMIT 10", }, ) + self.validate_all( + "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", + write={ + "spark": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", + "bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])", + "snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", + }, + ) + self.validate_all( + "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))", + write={ + "bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))", + }, + ) + + def test_user_defined_functions(self): + self.validate_identity( + "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 LANGUAGE js AS 'return x*y;'" + ) + self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)") + self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 4e0a3c6..e0ec824 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -13,9 +13,6 @@ from sqlglot import ( class Validator(unittest.TestCase): dialect = None - def validate(self, sql, target, **kwargs): - self.assertEqual(transpile(sql, **kwargs)[0], target) - def validate_identity(self, sql): self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql) @@ -258,6 +255,7 @@ class TestDialect(Validator): "duckdb": "EPOCH(STRPTIME('2020-01-01', '%Y-%M-%d'))", "hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-mm-dd')", "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%i-%d'))", + "starrocks": "UNIX_TIMESTAMP('2020-01-01', '%Y-%i-%d')", }, ) self.validate_all( @@ -266,6 +264,7 @@ class TestDialect(Validator): "duckdb": "CAST('2020-01-01' AS DATE)", "hive": "TO_DATE('2020-01-01')", "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", + "starrocks": "TO_DATE('2020-01-01')", }, ) self.validate_all( @@ -341,6 +340,7 @@ class TestDialect(Validator): "duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), y)", "hive": "FROM_UNIXTIME(x, y)", "presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)", + "starrocks": "FROM_UNIXTIME(x, y)", }, ) self.validate_all( @@ -349,6 +349,7 @@ class TestDialect(Validator): "duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))", "hive": "FROM_UNIXTIME(x)", "presto": "FROM_UNIXTIME(x)", + "starrocks": "FROM_UNIXTIME(x)", }, ) self.validate_all( @@ -841,9 +842,19 @@ class TestDialect(Validator): }, ) self.validate_all( + "POSITION(' ' in x)", + write={ + "duckdb": "STRPOS(x, ' ')", + "postgres": "STRPOS(x, ' ')", + "presto": "STRPOS(x, ' ')", + "spark": "LOCATE(' ', x)", + }, + ) + self.validate_all( "STR_POSITION(x, 'a')", write={ "duckdb": "STRPOS(x, 'a')", + "postgres": "STRPOS(x, 'a')", "presto": "STRPOS(x, 'a')", "spark": "LOCATE('a', x)", }, diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index f52decb..96e51df 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -1,3 +1,4 @@ +from sqlglot import ErrorLevel, UnsupportedError, transpile from tests.dialects.test_dialect import Validator @@ -250,3 +251,10 @@ class TestDuckDB(Validator): "spark": "MONTH('2021-03-01')", }, ) + + with self.assertRaises(UnsupportedError): + transpile( + "SELECT a FROM b PIVOT(SUM(x) FOR y IN ('z', 'q'))", + read="duckdb", + unsupported_level=ErrorLevel.IMMEDIATE, + ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index a9b5168..d335921 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -127,17 +127,17 @@ class TestHive(Validator): def test_ddl(self): self.validate_all( - "CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", + "CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", write={ - "presto": "CREATE TABLE test WITH (FORMAT = 'parquet', x = '1', Z = '2') AS SELECT 1", - "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", - "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", + "presto": "CREATE TABLE test WITH (FORMAT='parquet', x='1', Z='2') AS SELECT 1", + "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", + "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", }, ) self.validate_all( "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", write={ - "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])", + "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])", "hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", "spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", }, diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 87a3d64..02dc1ad 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -119,3 +119,39 @@ class TestMySQL(Validator): "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')", }, ) + self.validate_identity( + "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'" + ) + self.validate_identity( + "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'" + ) + self.validate_identity( + "CREATE TABLE z (a INT DEFAULT NULL, PRIMARY KEY(a)) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'" + ) + + self.validate_all( + """ + CREATE TABLE `t_customer_account` ( + "id" int(11) NOT NULL AUTO_INCREMENT, + "customer_id" int(11) DEFAULT NULL COMMENT '客户id', + "bank" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别', + "account_no" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号', + PRIMARY KEY ("id") + ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='客户账户表' + """, + write={ + "mysql": """CREATE TABLE `t_customer_account` ( + 'id' INT(11) NOT NULL AUTO_INCREMENT, + 'customer_id' INT(11) DEFAULT NULL COMMENT '客户id', + 'bank' VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别', + 'account_no' VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号', + PRIMARY KEY('id') +) +ENGINE=InnoDB +AUTO_INCREMENT=1 +DEFAULT CHARACTER SET=utf8 +COLLATE=utf8_bin +COMMENT='客户账户表'""" + }, + pretty=True, + ) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 96c299d..b0d9ad9 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -171,7 +171,7 @@ class TestPresto(Validator): self.validate_all( "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", write={ - "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", + "presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", "spark": "CREATE TABLE test USING PARQUET AS SELECT 1", }, @@ -179,15 +179,15 @@ class TestPresto(Validator): self.validate_all( "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1", write={ - "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1", - "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1", - "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1", + "presto": "CREATE TABLE test WITH (FORMAT='PARQUET', X='1', Z='2') AS SELECT 1", + "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X'='1', 'Z'='2') AS SELECT 1", + "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('X'='1', 'Z'='2') AS SELECT 1", }, ) self.validate_all( - "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])", + "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])", write={ - "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])", + "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])", "hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", "spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", }, @@ -195,9 +195,9 @@ class TestPresto(Validator): self.validate_all( "CREATE TABLE x WITH (bucket_by = ARRAY['y'], bucket_count = 64) AS SELECT 1 AS y", write={ - "presto": "CREATE TABLE x WITH (bucket_by = ARRAY['y'], bucket_count = 64) AS SELECT 1 AS y", - "hive": "CREATE TABLE x TBLPROPERTIES ('bucket_by' = ARRAY('y'), 'bucket_count' = 64) AS SELECT 1 AS y", - "spark": "CREATE TABLE x TBLPROPERTIES ('bucket_by' = ARRAY('y'), 'bucket_count' = 64) AS SELECT 1 AS y", + "presto": "CREATE TABLE x WITH (bucket_by=ARRAY['y'], bucket_count=64) AS SELECT 1 AS y", + "hive": "CREATE TABLE x TBLPROPERTIES ('bucket_by'=ARRAY('y'), 'bucket_count'=64) AS SELECT 1 AS y", + "spark": "CREATE TABLE x TBLPROPERTIES ('bucket_by'=ARRAY('y'), 'bucket_count'=64) AS SELECT 1 AS y", }, ) self.validate_all( @@ -217,11 +217,12 @@ class TestPresto(Validator): }, ) - self.validate( + self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", - "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", - read="presto", - write="presto", + write={ + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + }, ) def test_quotes(self): diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 165f8e2..b7e39a7 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -143,6 +143,31 @@ class TestSnowflake(Validator): "snowflake": r"SELECT 'a \' \\ \\t \\x21 z $ '", }, ) + self.validate_identity("SELECT REGEXP_LIKE(a, b, c)") + self.validate_all( + "SELECT RLIKE(a, b)", + write={ + "snowflake": "SELECT REGEXP_LIKE(a, b)", + }, + ) + self.validate_all( + "SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)", + write={ + "snowflake": "SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)", + }, + ) + self.validate_all( + "SELECT a FROM test pivot", + write={ + "snowflake": "SELECT a FROM test AS pivot", + }, + ) + self.validate_all( + "SELECT a FROM test unpivot", + write={ + "snowflake": "SELECT a FROM test AS unpivot", + }, + ) def test_null_treatment(self): self.validate_all( @@ -220,3 +245,51 @@ class TestSnowflake(Validator): "snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))", }, ) + + def test_semi_structured_types(self): + self.validate_identity("SELECT CAST(a AS VARIANT)") + self.validate_all( + "SELECT a::VARIANT", + write={ + "snowflake": "SELECT CAST(a AS VARIANT)", + "tsql": "SELECT CAST(a AS SQL_VARIANT)", + }, + ) + self.validate_identity("SELECT CAST(a AS ARRAY)") + self.validate_all( + "ARRAY_CONSTRUCT(0, 1, 2)", + write={ + "snowflake": "[0, 1, 2]", + "bigquery": "[0, 1, 2]", + "duckdb": "LIST_VALUE(0, 1, 2)", + "presto": "ARRAY[0, 1, 2]", + "spark": "ARRAY(0, 1, 2)", + }, + ) + self.validate_all( + "SELECT a::OBJECT", + write={ + "snowflake": "SELECT CAST(a AS OBJECT)", + }, + ) + + def test_ddl(self): + self.validate_identity( + "CREATE TABLE a (x DATE, y BIGINT) WITH (PARTITION BY (x), integration='q', auto_refresh=TRUE, file_format=(type = parquet))" + ) + self.validate_identity("CREATE MATERIALIZED VIEW a COMMENT='...' AS SELECT 1 FROM x") + + def test_user_defined_functions(self): + self.validate_all( + "CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS $$ SELECT 1 $$", + write={ + "snowflake": "CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS ' SELECT 1 '", + }, + ) + self.validate_all( + "CREATE FUNCTION a() RETURNS TABLE (b INT) AS 'SELECT 1'", + write={ + "snowflake": "CREATE FUNCTION a() RETURNS TABLE (b INT) AS 'SELECT 1'", + "bigquery": "CREATE TABLE FUNCTION a() RETURNS TABLE <b INT64> AS SELECT 1", + }, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 22f6947..8377e47 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -34,7 +34,7 @@ class TestSpark(Validator): self.validate_all( "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", write={ - "presto": "CREATE TABLE x WITH (TABLE_FORMAT = 'ICEBERG', PARTITIONED_BY = ARRAY['MONTHS'])", + "presto": "CREATE TABLE x WITH (TABLE_FORMAT = 'ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])", "hive": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", "spark": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", }, @@ -42,7 +42,7 @@ class TestSpark(Validator): self.validate_all( "CREATE TABLE test STORED AS PARQUET AS SELECT 1", write={ - "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", + "presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", "spark": "CREATE TABLE test USING PARQUET AS SELECT 1", }, @@ -56,9 +56,9 @@ class TestSpark(Validator): ) COMMENT='Test comment: blah' WITH ( - PARTITIONED_BY = ARRAY['date'], - FORMAT = 'ICEBERG', - x = '1' + PARTITIONED_BY=ARRAY['date'], + FORMAT='ICEBERG', + x='1' )""", "hive": """CREATE TABLE blah ( col_a INT @@ -69,7 +69,7 @@ PARTITIONED BY ( ) STORED AS ICEBERG TBLPROPERTIES ( - 'x' = '1' + 'x'='1' )""", "spark": """CREATE TABLE blah ( col_a INT @@ -80,7 +80,7 @@ PARTITIONED BY ( ) USING ICEBERG TBLPROPERTIES ( - 'x' = '1' + 'x'='1' )""", }, pretty=True, diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 0619eaa..6b0b39b 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -15,6 +15,14 @@ class TestTSQL(Validator): }, ) + self.validate_all( + "CONVERT(INT, CONVERT(NUMERIC, '444.75'))", + write={ + "mysql": "CAST(CAST('444.75' AS DECIMAL) AS INT)", + "tsql": "CAST(CAST('444.75' AS NUMERIC) AS INTEGER)", + }, + ) + def test_types(self): self.validate_identity("CAST(x AS XML)") self.validate_identity("CAST(x AS UNIQUEIDENTIFIER)") @@ -24,3 +32,13 @@ class TestTSQL(Validator): self.validate_identity("CAST(x AS IMAGE)") self.validate_identity("CAST(x AS SQL_VARIANT)") self.validate_identity("CAST(x AS BIT)") + self.validate_all( + "CAST(x AS DATETIME2)", + read={ + "": "CAST(x AS DATETIME)", + }, + write={ + "mysql": "CAST(x AS DATETIME)", + "tsql": "CAST(x AS DATETIME2)", + }, + ) |