diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dataframe/unit/test_session_case_sensitivity.py | 6 | ||||
-rw-r--r-- | tests/dialects/test_bigquery.py | 15 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 21 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 27 | ||||
-rw-r--r-- | tests/dialects/test_oracle.py | 41 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 22 | ||||
-rw-r--r-- | tests/fixtures/optimizer/optimizer.sql | 20 | ||||
-rw-r--r-- | tests/test_expressions.py | 3 | ||||
-rw-r--r-- | tests/test_optimizer.py | 11 | ||||
-rw-r--r-- | tests/test_parser.py | 2 |
11 files changed, 141 insertions, 28 deletions
diff --git a/tests/dataframe/unit/test_session_case_sensitivity.py b/tests/dataframe/unit/test_session_case_sensitivity.py index 7e35289..f9119b0 100644 --- a/tests/dataframe/unit/test_session_case_sensitivity.py +++ b/tests/dataframe/unit/test_session_case_sensitivity.py @@ -25,7 +25,7 @@ class TestSessionCaseSensitivity(DataFrameTestBase): '"Test"', {"name": "VARCHAR"}, "name", - '''SELECT "TEST"."NAME" AS "NAME" FROM "Test" AS "TEST"''', + '''SELECT "Test"."NAME" AS "NAME" FROM "Test" AS "Test"''', ), ( "Column has CS while table does not", @@ -41,7 +41,7 @@ class TestSessionCaseSensitivity(DataFrameTestBase): '"Test"', {'"Name"': "VARCHAR"}, '"Name"', - '''SELECT "TEST"."Name" AS "Name" FROM "Test" AS "TEST"''', + '''SELECT "Test"."Name" AS "Name" FROM "Test" AS "Test"''', ), ( "Lowercase CS table and column", @@ -49,7 +49,7 @@ class TestSessionCaseSensitivity(DataFrameTestBase): '"test"', {'"name"': "VARCHAR"}, '"name"', - '''SELECT "TEST"."name" AS "name" FROM "test" AS "TEST"''', + '''SELECT "test"."name" AS "name" FROM "test" AS "test"''', ), ( "CS table and column and query table but no CS in query column", diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index b776bdd..448a077 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -9,8 +9,11 @@ class TestBigQuery(Validator): maxDiff = None def test_bigquery(self): - self.validate_identity("""SELECT JSON '"foo"' AS json_data""") self.validate_identity("SELECT * FROM tbl FOR SYSTEM_TIME AS OF z") + self.validate_identity( + """SELECT JSON '"foo"' AS json_data""", + """SELECT PARSE_JSON('"foo"') AS json_data""", + ) self.validate_all( """SELECT @@ -176,10 +179,18 @@ WHERE write={ "bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)", "databricks": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", + "mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)", "spark": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", }, ) self.validate_all( + 'SELECT TIMESTAMP_SUB(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)', + write={ + "bigquery": "SELECT TIMESTAMP_SUB(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)", + "mysql": "SELECT DATE_SUB(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)", + }, + ) + self.validate_all( "MD5(x)", write={ "": "MD5_DIGEST(x)", @@ -796,7 +807,7 @@ WHERE ) self.validate_identity( """SELECT JSON_OBJECT(['a', 'b'], [JSON '10', JSON '"foo"']) AS json_data""", - """SELECT JSON_OBJECT('a', JSON '10', 'b', JSON '"foo"') AS json_data""", + """SELECT JSON_OBJECT('a', PARSE_JSON('10'), 'b', PARSE_JSON('"foo"')) AS json_data""", ) self.validate_identity( "SELECT JSON_OBJECT(['a', 'b'], [STRUCT(10 AS id, 'Red' AS color), STRUCT(20 AS id, 'Blue' AS color)]) AS json_data", diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index c5ee679..36fca7c 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -348,6 +348,27 @@ class TestDuckDB(Validator): "SELECT CAST('2020-05-06' AS DATE) + INTERVAL 5 DAY", read={"bigquery": "SELECT DATE_ADD(CAST('2020-05-06' AS DATE), INTERVAL 5 DAY)"}, ) + self.validate_all( + "SELECT QUANTILE_CONT(x, q) FROM t", + write={ + "duckdb": "SELECT QUANTILE_CONT(x, q) FROM t", + "postgres": "SELECT PERCENTILE_CONT(q) WITHIN GROUP (ORDER BY x) FROM t", + }, + ) + self.validate_all( + "SELECT QUANTILE_DISC(x, q) FROM t", + write={ + "duckdb": "SELECT QUANTILE_DISC(x, q) FROM t", + "postgres": "SELECT PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x) FROM t", + }, + ) + self.validate_all( + "SELECT MEDIAN(x) FROM t", + write={ + "duckdb": "SELECT QUANTILE_CONT(x, 0.5) FROM t", + "postgres": "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t", + }, + ) with self.assertRaises(UnsupportedError): transpile( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 6104e3f..e362e9e 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -19,6 +19,7 @@ class TestMySQL(Validator): }, ) + self.validate_identity("CREATE TABLE foo (a BIGINT, UNIQUE (b) USING BTREE)") self.validate_identity("CREATE TABLE foo (id BIGINT)") self.validate_identity("CREATE TABLE 00f (1d BIGINT)") self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10") @@ -107,9 +108,8 @@ class TestMySQL(Validator): ) def test_identity(self): - self.validate_identity( - "SELECT * FROM x ORDER BY BINARY a", "SELECT * FROM x ORDER BY CAST(a AS BINARY)" - ) + self.validate_identity("UNLOCK TABLES") + self.validate_identity("LOCK TABLES `app_fields` WRITE") self.validate_identity("SELECT 1 XOR 0") self.validate_identity("SELECT 1 && 0", "SELECT 1 AND 0") self.validate_identity("SELECT /*+ BKA(t1) NO_BKA(t2) */ * FROM t1 INNER JOIN t2") @@ -134,6 +134,9 @@ class TestMySQL(Validator): self.validate_identity("SELECT * FROM t1, t2 FOR SHARE OF t1, t2 SKIP LOCKED") self.validate_identity("SELECT a || b", "SELECT a OR b") self.validate_identity( + "SELECT * FROM x ORDER BY BINARY a", "SELECT * FROM x ORDER BY CAST(a AS BINARY)" + ) + self.validate_identity( """SELECT * FROM foo WHERE 3 MEMBER OF(JSON_EXTRACT(info, '$.value'))""" ) self.validate_identity( @@ -546,7 +549,7 @@ class TestMySQL(Validator): "oracle": "SELECT a FROM tbl FOR UPDATE", "postgres": "SELECT a FROM tbl FOR UPDATE", "redshift": "SELECT a FROM tbl", - "tsql": "SELECT a FROM tbl FOR UPDATE", + "tsql": "SELECT a FROM tbl", }, ) self.validate_all( @@ -556,7 +559,7 @@ class TestMySQL(Validator): "mysql": "SELECT a FROM tbl FOR SHARE", "oracle": "SELECT a FROM tbl FOR SHARE", "postgres": "SELECT a FROM tbl FOR SHARE", - "tsql": "SELECT a FROM tbl FOR SHARE", + "tsql": "SELECT a FROM tbl", }, ) self.validate_all( @@ -868,3 +871,17 @@ COMMENT='客户账户表'""" def test_json_object(self): self.validate_identity("SELECT JSON_OBJECT('id', 87, 'name', 'carrot')") + + def test_is_null(self): + self.validate_all( + "SELECT ISNULL(x)", write={"": "SELECT (x IS NULL)", "mysql": "SELECT (x IS NULL)"} + ) + + def test_monthname(self): + self.validate_all( + "MONTHNAME(x)", + write={ + "": "TIME_TO_STR(x, '%B')", + "mysql": "DATE_FORMAT(x, '%M')", + }, + ) diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 2dfd179..675ee8a 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -6,6 +6,7 @@ class TestOracle(Validator): dialect = "oracle" def test_oracle(self): + self.validate_identity("SELECT x FROM t WHERE cond FOR UPDATE") self.validate_identity("SELECT JSON_OBJECT(k1: v1 FORMAT JSON, k2: v2 FORMAT JSON)") self.validate_identity("SELECT JSON_OBJECT('name': first_name || ' ' || last_name) FROM t") self.validate_identity("COALESCE(c1, c2, c3)") @@ -50,6 +51,14 @@ class TestOracle(Validator): "SELECT UNIQUE col1, col2 FROM table", "SELECT DISTINCT col1, col2 FROM table", ) + self.validate_identity( + "SELECT * FROM T ORDER BY I OFFSET nvl(:variable1, 10) ROWS FETCH NEXT nvl(:variable2, 10) ROWS ONLY", + "SELECT * FROM T ORDER BY I OFFSET COALESCE(:variable1, 10) ROWS FETCH NEXT COALESCE(:variable2, 10) ROWS ONLY", + ) + self.validate_identity( + "SELECT * FROM t SAMPLE (.25)", + "SELECT * FROM t SAMPLE (0.25)", + ) self.validate_all( "NVL(NULL, 1)", @@ -82,6 +91,16 @@ class TestOracle(Validator): "": "CAST(x AS FLOAT)", }, ) + self.validate_all( + "CAST(x AS sch.udt)", + read={ + "postgres": "CAST(x AS sch.udt)", + }, + write={ + "oracle": "CAST(x AS sch.udt)", + "postgres": "CAST(x AS sch.udt)", + }, + ) def test_join_marker(self): self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y") @@ -218,3 +237,25 @@ INNER JOIN JSON_TABLE(:emps, '$[*]' COLUMNS (empno NUMBER PATH '$')) jt ON ar.empno = jt.empno""", pretty=True, ) + + def test_connect_by(self): + start = "START WITH last_name = 'King'" + connect = "CONNECT BY PRIOR employee_id = manager_id AND LEVEL <= 4" + body = """ + SELECT last_name "Employee", + LEVEL, SYS_CONNECT_BY_PATH(last_name, '/') "Path" + FROM employees + WHERE level <= 3 AND department_id = 80 + """ + pretty = """SELECT + last_name AS "Employee", + LEVEL, + SYS_CONNECT_BY_PATH(last_name, '/') AS "Path" +FROM employees +START WITH last_name = 'King' +CONNECT BY PRIOR employee_id = manager_id AND LEVEL <= 4 +WHERE + level <= 3 AND department_id = 80""" + + for query in (f"{body}{start}{connect}", f"{body}{connect}{start}"): + self.validate_identity(query, pretty, pretty=True) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 151f3af..285496a 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -207,6 +207,7 @@ class TestPostgres(Validator): "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY amount)", write={ "databricks": "SELECT PERCENTILE_APPROX(amount, 0.5)", + "postgres": "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY amount)", "presto": "SELECT APPROX_PERCENTILE(amount, 0.5)", "spark": "SELECT PERCENTILE_APPROX(amount, 0.5)", "trino": "SELECT APPROX_PERCENTILE(amount, 0.5)", diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index dbca5b3..a92f04f 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -358,6 +358,10 @@ class TestPresto(Validator): write={"presto": "CAST(x AS TIMESTAMP)"}, read={"mysql": "CAST(x AS DATETIME)", "clickhouse": "CAST(x AS DATETIME64)"}, ) + self.validate_all( + "CAST(x AS TIMESTAMP)", + read={"mysql": "TIMESTAMP(x)"}, + ) def test_ddl(self): self.validate_all( @@ -518,6 +522,14 @@ class TestPresto(Validator): ) self.validate_all( + """JSON '"foo"'""", + write={ + "bigquery": """PARSE_JSON('"foo"')""", + "presto": """JSON_PARSE('"foo"')""", + "snowflake": """PARSE_JSON('"foo"')""", + }, + ) + self.validate_all( "SELECT ROW(1, 2)", read={ "spark": "SELECT STRUCT(1, 2)", @@ -824,9 +836,9 @@ class TestPresto(Validator): self.validate_all( """JSON_FORMAT(JSON '"x"')""", write={ - "bigquery": """TO_JSON_STRING(JSON '"x"')""", - "duckdb": """CAST(TO_JSON(CAST('"x"' AS JSON)) AS TEXT)""", - "presto": """JSON_FORMAT(CAST('"x"' AS JSON))""", + "bigquery": """TO_JSON_STRING(PARSE_JSON('"x"'))""", + "duckdb": """CAST(TO_JSON(JSON('"x"')) AS TEXT)""", + "presto": """JSON_FORMAT(JSON_PARSE('"x"'))""", "spark": """REGEXP_EXTRACT(TO_JSON(FROM_JSON('["x"]', SCHEMA_OF_JSON('["x"]'))), '^.(.*).$', 1)""", }, ) @@ -916,14 +928,14 @@ class TestPresto(Validator): "SELECT CAST(JSON '[1,23,456]' AS ARRAY(INTEGER))", write={ "spark": "SELECT FROM_JSON('[1,23,456]', 'ARRAY<INT>')", - "presto": "SELECT CAST(CAST('[1,23,456]' AS JSON) AS ARRAY(INTEGER))", + "presto": "SELECT CAST(JSON_PARSE('[1,23,456]') AS ARRAY(INTEGER))", }, ) self.validate_all( """SELECT CAST(JSON '{"k1":1,"k2":23,"k3":456}' AS MAP(VARCHAR, INTEGER))""", write={ "spark": 'SELECT FROM_JSON(\'{"k1":1,"k2":23,"k3":456}\', \'MAP<STRING, INT>\')', - "presto": 'SELECT CAST(CAST(\'{"k1":1,"k2":23,"k3":456}\' AS JSON) AS MAP(VARCHAR, INTEGER))', + "presto": 'SELECT CAST(JSON_PARSE(\'{"k1":1,"k2":23,"k3":456}\') AS MAP(VARCHAR, INTEGER))', }, ) diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index 4a994c1..e59f14d 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -577,10 +577,10 @@ FROM `u_cte` AS `u_cte` PIVOT(SUM(`u_cte`.`f`) AS `sum` FOR `u_cte`.`h` IN ('x', # dialect: snowflake SELECT * FROM u PIVOT (SUM(f) FOR h IN ('x', 'y')); SELECT - "_Q_0"."G" AS "G", - "_Q_0"."'x'" AS "'x'", - "_Q_0"."'y'" AS "'y'" -FROM "U" AS "U" PIVOT(SUM("U"."F") FOR "U"."H" IN ('x', 'y')) AS "_Q_0" + "_q_0"."G" AS "G", + "_q_0"."'x'" AS "'x'", + "_q_0"."'y'" AS "'y'" +FROM "U" AS "U" PIVOT(SUM("U"."F") FOR "U"."H" IN ('x', 'y')) AS "_q_0" ; # title: selecting all columns from a pivoted source and generating spark @@ -668,13 +668,13 @@ WHERE GROUP BY `dAy`, `top_term`, rank ORDER BY `DaY` DESC; SELECT - `top_terms`.`refresh_date` AS `day`, - `top_terms`.`term` AS `top_term`, - `top_terms`.`rank` AS `rank` -FROM `bigquery-public-data`.`GooGle_tReNDs`.`TOp_TeRmS` AS `top_terms` + `TOp_TeRmS`.`refresh_date` AS `day`, + `TOp_TeRmS`.`term` AS `top_term`, + `TOp_TeRmS`.`rank` AS `rank` +FROM `bigquery-public-data`.`GooGle_tReNDs`.`TOp_TeRmS` AS `TOp_TeRmS` WHERE - `top_terms`.`rank` = 1 - AND CAST(`top_terms`.`refresh_date` AS DATE) >= DATE_SUB(CURRENT_DATE, INTERVAL 2 WEEK) + `TOp_TeRmS`.`rank` = 1 + AND CAST(`TOp_TeRmS`.`refresh_date` AS DATE) >= DATE_SUB(CURRENT_DATE, INTERVAL 2 WEEK) GROUP BY `day`, `top_term`, diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 5d1f810..b3ce926 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -934,5 +934,8 @@ FROM foo""", assert dtype.is_type("foo") assert not dtype.is_type("bar") + dtype = exp.DataType.build("a.b.c", udt=True) + assert dtype.is_type("a.b.c") + with self.assertRaises(ParseError): exp.DataType.build("foo") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 4415e03..a40f089 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -223,12 +223,12 @@ class TestOptimizer(unittest.TestCase): self.assertEqual( optimizer.qualify.qualify( parse_one( - "WITH X AS (SELECT Y.A FROM DB.Y CROSS JOIN a.b.INFORMATION_SCHEMA.COLUMNS) SELECT `A` FROM X", + "WITH X AS (SELECT Y.A FROM DB.y CROSS JOIN a.b.INFORMATION_SCHEMA.COLUMNS) SELECT `A` FROM X", read="bigquery", ), dialect="bigquery", ).sql(), - 'WITH "x" AS (SELECT "y"."a" AS "a" FROM "DB"."Y" AS "y" CROSS JOIN "a"."b"."INFORMATION_SCHEMA"."COLUMNS" AS "columns") SELECT "x"."a" AS "a" FROM "x"', + 'WITH "x" AS (SELECT "y"."a" AS "a" FROM "DB"."y" AS "y" CROSS JOIN "a"."b"."INFORMATION_SCHEMA"."COLUMNS" AS "COLUMNS") SELECT "x"."a" AS "a" FROM "x"', ) self.assertEqual( @@ -776,6 +776,13 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(exp.DataType.Type.ARRAY, expression.selects[0].type.this) self.assertEqual(expression.selects[0].type.sql(), "ARRAY<INT>") + def test_user_defined_type_annotation(self): + schema = MappingSchema({"t": {"x": "int"}}, dialect="postgres") + expression = annotate_types(parse_one("SELECT CAST(x AS IPADDRESS) FROM t"), schema=schema) + + self.assertEqual(exp.DataType.Type.USERDEFINED, expression.selects[0].type.this) + self.assertEqual(expression.selects[0].type.sql(dialect="postgres"), "IPADDRESS") + def test_recursive_cte(self): query = parse_one( """ diff --git a/tests/test_parser.py b/tests/test_parser.py index ad9b941..74463fd 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -433,7 +433,7 @@ class TestParser(unittest.TestCase): self.assertEqual(parse_one("TIMESTAMP(1) WITH TIME ZONE").sql(), "TIMESTAMPTZ(1)") self.assertEqual(parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE").sql(), "TIMESTAMPLTZ(1)") self.assertEqual(parse_one("TIMESTAMP(1) WITHOUT TIME ZONE").sql(), "TIMESTAMP(1)") - self.assertEqual(parse_one("""JSON '{"x":"y"}'""").sql(), """CAST('{"x":"y"}' AS JSON)""") + self.assertEqual(parse_one("""JSON '{"x":"y"}'""").sql(), """PARSE_JSON('{"x":"y"}')""") self.assertIsInstance(parse_one("TIMESTAMP(1)"), exp.Func) self.assertIsInstance(parse_one("TIMESTAMP('2022-01-01')"), exp.Func) self.assertIsInstance(parse_one("TIMESTAMP()"), exp.Func) |