diff options
Diffstat (limited to 'tests/dialects')
-rw-r--r-- | tests/dialects/test_bigquery.py | 11 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 58 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 24 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 5 | ||||
-rw-r--r-- | tests/dialects/test_oracle.py | 18 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 5 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 6 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_sqlite.py | 5 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 19 |
12 files changed, 147 insertions, 8 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index e731b50..e210292 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -6,6 +6,8 @@ class TestBigQuery(Validator): dialect = "bigquery" def test_bigquery(self): + self.validate_identity("SELECT AS STRUCT 1 AS a, 2 AS b") + self.validate_identity("SELECT AS VALUE STRUCT(1 AS a, 2 AS b)") self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])") self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))") self.validate_identity( @@ -13,6 +15,15 @@ class TestBigQuery(Validator): ) self.validate_all("LEAST(x, y)", read={"sqlite": "MIN(x, y)"}) + self.validate_all("CAST(x AS CHAR)", write={"bigquery": "CAST(x AS STRING)"}) + self.validate_all("CAST(x AS NCHAR)", write={"bigquery": "CAST(x AS STRING)"}) + self.validate_all("CAST(x AS NVARCHAR)", write={"bigquery": "CAST(x AS STRING)"}) + self.validate_all( + "SELECT ARRAY(SELECT AS STRUCT 1 a, 2 b)", + write={ + "bigquery": "SELECT ARRAY(SELECT AS STRUCT 1 AS a, 2 AS b)", + }, + ) self.validate_all( "REGEXP_CONTAINS('foo', '.*')", read={"bigquery": "REGEXP_CONTAINS('foo', '.*')"}, diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 6214c43..0805e9c 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -812,11 +812,13 @@ class TestDialect(Validator): self.validate_all( "JSON_EXTRACT(x, 'y')", read={ + "mysql": "JSON_EXTRACT(x, 'y')", "postgres": "x->'y'", "presto": "JSON_EXTRACT(x, 'y')", "starrocks": "x -> 'y'", }, write={ + "mysql": "JSON_EXTRACT(x, 'y')", "oracle": "JSON_EXTRACT(x, 'y')", "postgres": "x -> 'y'", "presto": "JSON_EXTRACT(x, 'y')", @@ -835,6 +837,17 @@ class TestDialect(Validator): }, ) self.validate_all( + "JSON_EXTRACT_SCALAR(stream_data, '$.data.results')", + read={ + "hive": "GET_JSON_OBJECT(stream_data, '$.data.results')", + "mysql": "stream_data ->> '$.data.results'", + }, + write={ + "hive": "GET_JSON_OBJECT(stream_data, '$.data.results')", + "mysql": "stream_data ->> '$.data.results'", + }, + ) + self.validate_all( "JSONB_EXTRACT(x, 'y')", read={ "postgres": "x#>'y'", @@ -1000,6 +1013,7 @@ class TestDialect(Validator): self.validate_identity("some.column LIKE 'foo' || another.column || 'bar' || LOWER(x)") self.validate_identity("some.column LIKE 'foo' + another.column + 'bar'") + self.validate_all("LIKE(x, 'z')", write={"": "'z' LIKE x"}) self.validate_all( "x ILIKE '%y'", read={ @@ -1196,9 +1210,13 @@ class TestDialect(Validator): ) self.validate_all( "SELECT x FROM y LIMIT 10", + read={ + "tsql": "SELECT TOP 10 x FROM y", + }, write={ "sqlite": "SELECT x FROM y LIMIT 10", "oracle": "SELECT x FROM y FETCH FIRST 10 ROWS ONLY", + "tsql": "SELECT x FROM y FETCH FIRST 10 ROWS ONLY", }, ) self.validate_all( @@ -1493,6 +1511,46 @@ SELECT }, ) + def test_logarithm(self): + self.validate_all( + "LOG(x)", + read={ + "duckdb": "LOG(x)", + "postgres": "LOG(x)", + "redshift": "LOG(x)", + "sqlite": "LOG(x)", + "teradata": "LOG(x)", + }, + ) + self.validate_all( + "LN(x)", + read={ + "bigquery": "LOG(x)", + "clickhouse": "LOG(x)", + "databricks": "LOG(x)", + "drill": "LOG(x)", + "hive": "LOG(x)", + "mysql": "LOG(x)", + "tsql": "LOG(x)", + }, + ) + self.validate_all( + "LOG(b, n)", + read={ + "bigquery": "LOG(n, b)", + "databricks": "LOG(b, n)", + "drill": "LOG(b, n)", + "hive": "LOG(b, n)", + "mysql": "LOG(b, n)", + "oracle": "LOG(b, n)", + "postgres": "LOG(b, n)", + "snowflake": "LOG(b, n)", + "spark": "LOG(b, n)", + "sqlite": "LOG(b, n)", + "tsql": "LOG(n, b)", + }, + ) + def test_count_if(self): self.validate_identity("COUNT_IF(DISTINCT cond)") diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 1cabade..a15e6b4 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -125,6 +125,7 @@ class TestDuckDB(Validator): "SELECT a['x space'] FROM (SELECT {'x space': 1, 'y': 2, 'z': 3} AS a)" ) + self.validate_all("x ~ y", write={"duckdb": "REGEXP_MATCHES(x, y)"}) self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'}) self.validate_all( "WITH 'x' AS (SELECT 1) SELECT * FROM x", diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 8484805..0161f1e 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -247,6 +247,30 @@ class TestHive(Validator): def test_time(self): self.validate_all( + "(UNIX_TIMESTAMP(y) - UNIX_TIMESTAMP(x)) * 1000", + read={ + "presto": "DATE_DIFF('millisecond', x, y)", + }, + ) + self.validate_all( + "UNIX_TIMESTAMP(y) - UNIX_TIMESTAMP(x)", + read={ + "presto": "DATE_DIFF('second', x, y)", + }, + ) + self.validate_all( + "(UNIX_TIMESTAMP(y) - UNIX_TIMESTAMP(x)) / 60", + read={ + "presto": "DATE_DIFF('minute', x, y)", + }, + ) + self.validate_all( + "(UNIX_TIMESTAMP(y) - UNIX_TIMESTAMP(x)) / 3600", + read={ + "presto": "DATE_DIFF('hour', x, y)", + }, + ) + self.validate_all( "DATEDIFF(a, b)", write={ "duckdb": "DATE_DIFF('day', CAST(b AS DATE), CAST(a AS DATE))", diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 5f8560a..5059d05 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -16,6 +16,7 @@ class TestMySQL(Validator): ) def test_identity(self): + self.validate_identity("x ->> '$.name'") self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ')") self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')") @@ -424,6 +425,10 @@ COMMENT='客户账户表'""" show = self.validate_identity("SHOW INDEX FROM foo FROM bar") self.assertEqual(show.text("db"), "bar") + self.validate_all( + "SHOW INDEX FROM bar.foo", write={"mysql": "SHOW INDEX FROM foo FROM bar"} + ) + def test_show_db_like_or_where_sql(self): for key in [ "OPEN TABLES", diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 4dc3f1b..80fa0f1 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -12,6 +12,24 @@ class TestOracle(Validator): self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)") self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y (+)") + def test_hints(self): + self.validate_identity("SELECT /*+ USE_NL(A B) */ A.COL_TEST FROM TABLE_A A, TABLE_B B") + self.validate_identity( + "SELECT /*+ INDEX(v.j jhist_employee_ix (employee_id start_date)) */ * FROM v" + ) + self.validate_identity( + "SELECT /*+ USE_NL(A B C) */ A.COL_TEST FROM TABLE_A A, TABLE_B B, TABLE_C C" + ) + self.validate_identity( + "SELECT /*+ NO_INDEX(employees emp_empid) */ employee_id FROM employees WHERE employee_id > 200" + ) + self.validate_identity( + "SELECT /*+ NO_INDEX_FFS(items item_order_ix) */ order_id FROM order_items items" + ) + self.validate_identity( + "SELECT /*+ LEADING(e j) */ * FROM employees e, departments d, job_history j WHERE e.department_id = d.department_id AND e.hire_date = j.start_date" + ) + def test_xml_table(self): self.validate_identity("XMLTABLE('x')") self.validate_identity("XMLTABLE('x' RETURNING SEQUENCE BY REF)") diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index c8dea95..a89ae30 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -194,8 +194,9 @@ class TestPostgres(Validator): write={ "postgres": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY", "presto": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY", - "hive": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY", - "spark": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY", + "hive": "SELECT * FROM x LIMIT 1", + "spark": "SELECT * FROM x LIMIT 1", + "sqlite": "SELECT * FROM x LIMIT 1", }, ) self.validate_all( diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 0a9111c..1762e7a 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -370,6 +370,12 @@ class TestPresto(Validator): self.validate_identity("APPROX_PERCENTILE(a, b, c, d)") self.validate_all( + "SELECT JSON_OBJECT(KEY 'key1' VALUE 1, KEY 'key2' VALUE TRUE)", + write={ + "presto": "SELECT JSON_OBJECT('key1': 1, 'key2': TRUE)", + }, + ) + self.validate_all( "ARRAY_AGG(x ORDER BY y DESC)", write={ "hive": "COLLECT_LIST(x)", diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 5f6efce..940fa50 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -530,6 +530,7 @@ class TestSnowflake(Validator): "snowflake": "DATEADD(DAY, 5, CAST('2008-12-25' AS DATE))", }, ) + self.validate_identity("DATEDIFF(DAY, 5, CAST('2008-12-25' AS DATE))") def test_semi_structured_types(self): self.validate_identity("SELECT CAST(a AS VARIANT)") @@ -814,6 +815,7 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA self.assertIsInstance(like, exp.LikeAny) self.assertIsInstance(ilike, exp.ILikeAny) + like.sql() # check that this doesn't raise def test_match_recognize(self): for row in ( diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 5b21349..b12f272 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -212,6 +212,7 @@ TBLPROPERTIES ( self.validate_identity("TRIM(BOTH 'SL' FROM 'SSparkSQLS')") self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')") self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')") + self.validate_identity("SPLIT(str, pattern, lim)") self.validate_all( "CAST(x AS TIMESTAMP)", read={"trino": "CAST(x AS TIMESTAMP(6) WITH TIME ZONE)"} diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index 98c4a79..fd9e52b 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -56,6 +56,11 @@ class TestSQLite(Validator): ) def test_sqlite(self): + self.validate_all("SELECT LIKE(y, x)", write={"sqlite": "SELECT x LIKE y"}) + self.validate_all("SELECT GLOB('*y*', 'xyz')", write={"sqlite": "SELECT 'xyz' GLOB '*y*'"}) + self.validate_all( + "SELECT LIKE('%y%', 'xyz', '')", write={"sqlite": "SELECT 'xyz' LIKE '%y%' ESCAPE ''"} + ) self.validate_all( "CURRENT_DATE", read={ diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 4224a1e..60867be 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -7,6 +7,7 @@ class TestTSQL(Validator): def test_tsql(self): self.validate_identity("SELECT CASE WHEN a > 1 THEN b END") + self.validate_identity("SELECT * FROM taxi ORDER BY 1 OFFSET 0 ROWS FETCH NEXT 3 ROWS ONLY") self.validate_identity("END") self.validate_identity("@x") self.validate_identity("#x") @@ -567,15 +568,21 @@ WHERE write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"}, ) - def test_variables(self): - # In TSQL @, # can be used as a prefix for variables/identifiers - expr = parse_one("@x", read="tsql") - self.assertIsInstance(expr, exp.Column) - self.assertIsInstance(expr.this, exp.Identifier) - + def test_identifier_prefixes(self): expr = parse_one("#x", read="tsql") self.assertIsInstance(expr, exp.Column) self.assertIsInstance(expr.this, exp.Identifier) + self.assertEqual(expr.sql("tsql"), "#x") + + expr = parse_one("@x", read="tsql") + self.assertIsInstance(expr, exp.Parameter) + self.assertIsInstance(expr.this, exp.Var) + self.assertEqual(expr.sql("tsql"), "@x") + + table = parse_one("select * from @x", read="tsql").args["from"].expressions[0] + self.assertIsInstance(table, exp.Table) + self.assertIsInstance(table.this, exp.Parameter) + self.assertIsInstance(table.this.this, exp.Var) def test_system_time(self): self.validate_all( |