diff options
Diffstat (limited to 'tests/dialects')
-rw-r--r-- | tests/dialects/test_bigquery.py | 131 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 108 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 52 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 16 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 11 |
7 files changed, 302 insertions, 25 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index eeb49f3..c6cfe01 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -10,6 +10,7 @@ from sqlglot import ( exp, parse, transpile, + parse_one, ) from sqlglot.helper import logger as helper_logger from sqlglot.parser import logger as parser_logger @@ -21,8 +22,6 @@ class TestBigQuery(Validator): maxDiff = None def test_bigquery(self): - self.validate_identity("REGEXP_EXTRACT(x, '(?<)')") - self.validate_all( "EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))", write={ @@ -182,7 +181,6 @@ LANGUAGE js AS self.validate_identity("""CREATE TABLE x (a STRUCT<values ARRAY<INT64>>)""") self.validate_identity("""CREATE TABLE x (a STRUCT<b STRING OPTIONS (description='b')>)""") self.validate_identity("CAST(x AS TIMESTAMP)") - self.validate_identity("REGEXP_EXTRACT(`foo`, 'bar: (.+?)', 1, 1)") self.validate_identity("BEGIN DECLARE y INT64", check_command_warning=True) self.validate_identity("BEGIN TRANSACTION") self.validate_identity("COMMIT TRANSACTION") @@ -202,9 +200,24 @@ LANGUAGE js AS self.validate_identity("CAST(x AS NVARCHAR)", "CAST(x AS STRING)") self.validate_identity("CAST(x AS TIMESTAMPTZ)", "CAST(x AS TIMESTAMP)") self.validate_identity("CAST(x AS RECORD)", "CAST(x AS STRUCT)") - self.validate_identity("EDIT_DISTANCE('a', 'a', max_distance => 2)").assert_is( - exp.Levenshtein + self.validate_all( + "EDIT_DISTANCE(col1, col2, max_distance => 3)", + write={ + "bigquery": "EDIT_DISTANCE(col1, col2, max_distance => 3)", + "clickhouse": UnsupportedError, + "databricks": UnsupportedError, + "drill": UnsupportedError, + "duckdb": UnsupportedError, + "hive": UnsupportedError, + "postgres": "LEVENSHTEIN_LESS_EQUAL(col1, col2, 3)", + "presto": UnsupportedError, + "snowflake": "EDITDISTANCE(col1, col2, 3)", + "spark": UnsupportedError, + "spark2": UnsupportedError, + "sqlite": UnsupportedError, + }, ) + self.validate_identity( "MERGE INTO dataset.NewArrivals USING (SELECT * FROM UNNEST([('microwave', 10, 'warehouse #1'), ('dryer', 30, 'warehouse #1'), ('oven', 20, 'warehouse #2')])) ON FALSE WHEN NOT MATCHED THEN INSERT ROW WHEN NOT MATCHED BY SOURCE THEN DELETE" ) @@ -315,10 +328,6 @@ LANGUAGE js AS "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))", ) self.validate_identity( - r"REGEXP_EXTRACT(svc_plugin_output, r'\\\((.*)')", - r"REGEXP_EXTRACT(svc_plugin_output, '\\\\\\((.*)')", - ) - self.validate_identity( "SELECT CAST(1 AS BYTEINT)", "SELECT CAST(1 AS INT64)", ) @@ -1378,14 +1387,6 @@ LANGUAGE js AS "postgres": "SELECT * FROM (VALUES (1)) AS t1(id) CROSS JOIN (VALUES (1)) AS t2(id)", }, ) - - self.validate_all( - "SELECT REGEXP_EXTRACT(abc, 'pattern(group)') FROM table", - write={ - "bigquery": "SELECT REGEXP_EXTRACT(abc, 'pattern(group)') FROM table", - "duckdb": '''SELECT REGEXP_EXTRACT(abc, 'pattern(group)', 1) FROM "table"''', - }, - ) self.validate_all( "SELECT * FROM UNNEST([1]) WITH OFFSET", write={"bigquery": "SELECT * FROM UNNEST([1]) WITH OFFSET AS offset"}, @@ -1602,6 +1603,14 @@ WHERE "snowflake": """SELECT GET_PATH(PARSE_JSON('{"class": {"students": []}}'), 'class')""", }, ) + self.validate_all( + """SELECT JSON_VALUE_ARRAY('{"arr": [1, "a"]}', '$.arr')""", + write={ + "bigquery": """SELECT JSON_VALUE_ARRAY('{"arr": [1, "a"]}', '$.arr')""", + "duckdb": """SELECT CAST('{"arr": [1, "a"]}' -> '$.arr' AS TEXT[])""", + "snowflake": """SELECT TRANSFORM(GET_PATH(PARSE_JSON('{"arr": [1, "a"]}'), 'arr'), x -> CAST(x AS VARCHAR))""", + }, + ) def test_errors(self): with self.assertRaises(TokenError): @@ -2116,3 +2125,91 @@ OPTIONS ( "snowflake": """SELECT JSON_EXTRACT_PATH_TEXT('{"name": "Jakob", "age": "6"}', 'age')""", }, ) + + def test_json_extract_array(self): + for func in ("JSON_QUERY_ARRAY", "JSON_EXTRACT_ARRAY"): + with self.subTest(f"Testing BigQuery's {func}"): + self.validate_all( + f"""SELECT {func}('{{"fruits": [1, "oranges"]}}', '$.fruits')""", + write={ + "bigquery": f"""SELECT {func}('{{"fruits": [1, "oranges"]}}', '$.fruits')""", + "duckdb": """SELECT CAST('{"fruits": [1, "oranges"]}' -> '$.fruits' AS JSON[])""", + "snowflake": """SELECT TRANSFORM(GET_PATH(PARSE_JSON('{"fruits": [1, "oranges"]}'), 'fruits'), x -> PARSE_JSON(TO_JSON(x)))""", + }, + ) + + def test_unix_seconds(self): + self.validate_all( + "SELECT UNIX_SECONDS('2008-12-25 15:30:00+00')", + read={ + "bigquery": "SELECT UNIX_SECONDS('2008-12-25 15:30:00+00')", + "spark": "SELECT UNIX_SECONDS('2008-12-25 15:30:00+00')", + "databricks": "SELECT UNIX_SECONDS('2008-12-25 15:30:00+00')", + }, + write={ + "spark": "SELECT UNIX_SECONDS('2008-12-25 15:30:00+00')", + "databricks": "SELECT UNIX_SECONDS('2008-12-25 15:30:00+00')", + "duckdb": "SELECT DATE_DIFF('SECONDS', CAST('1970-01-01 00:00:00+00' AS TIMESTAMPTZ), '2008-12-25 15:30:00+00')", + "snowflake": "SELECT TIMESTAMPDIFF(SECONDS, CAST('1970-01-01 00:00:00+00' AS TIMESTAMPTZ), '2008-12-25 15:30:00+00')", + }, + ) + + for dialect in ("bigquery", "spark", "databricks"): + parse_one("UNIX_SECONDS(col)", dialect=dialect).assert_is(exp.UnixSeconds) + + def test_regexp_extract(self): + self.validate_identity("REGEXP_EXTRACT(x, '(?<)')") + self.validate_identity("REGEXP_EXTRACT(`foo`, 'bar: (.+?)', 1, 1)") + self.validate_identity( + r"REGEXP_EXTRACT(svc_plugin_output, r'\\\((.*)')", + r"REGEXP_EXTRACT(svc_plugin_output, '\\\\\\((.*)')", + ) + self.validate_identity( + r"REGEXP_SUBSTR(value, pattern, position, occurence)", + r"REGEXP_EXTRACT(value, pattern, position, occurence)", + ) + + self.validate_all( + "SELECT REGEXP_EXTRACT(abc, 'pattern(group)') FROM table", + write={ + "bigquery": "SELECT REGEXP_EXTRACT(abc, 'pattern(group)') FROM table", + "duckdb": '''SELECT REGEXP_EXTRACT(abc, 'pattern(group)', 1) FROM "table"''', + }, + ) + + # The pattern does not capture a group (entire regular expression is extracted) + self.validate_all( + "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]')", + read={ + "bigquery": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]')", + "trino": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]')", + "presto": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]')", + "snowflake": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]')", + "duckdb": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]', 0)", + "spark": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]', 0)", + "databricks": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]', 0)", + }, + write={ + "bigquery": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]')", + "trino": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]')", + "presto": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]')", + "snowflake": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]')", + "duckdb": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]', 0)", + "spark": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]', 0)", + "databricks": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', 'a[0-9]', 0)", + }, + ) + + # The pattern does capture >=1 group (the default is to extract the first instance) + self.validate_all( + "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', '(a)[0-9]')", + write={ + "bigquery": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', '(a)[0-9]')", + "trino": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', '(a)[0-9]', 1)", + "presto": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', '(a)[0-9]', 1)", + "snowflake": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', '(a)[0-9]', 1, 1, 'c', 1)", + "duckdb": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', '(a)[0-9]', 1)", + "spark": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', '(a)[0-9]')", + "databricks": "REGEXP_EXTRACT_ALL('a1_a2a3_a4A5a6', '(a)[0-9]')", + }, + ) diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index a0efb54..5a4461e 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -82,6 +82,7 @@ class TestClickhouse(Validator): self.validate_identity("SELECT histogram(5)(a)") self.validate_identity("SELECT groupUniqArray(2)(a)") self.validate_identity("SELECT exponentialTimeDecayedAvg(60)(a, b)") + self.validate_identity("levenshteinDistance(col1, col2)", "editDistance(col1, col2)") self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)") self.validate_identity("position(haystack, needle)") self.validate_identity("position(haystack, needle, position)") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 85402e2..170b64b 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1444,6 +1444,56 @@ class TestDialect(Validator): }, ) + # UNNEST without column alias + self.validate_all( + "SELECT * FROM x CROSS JOIN UNNEST(y) AS t", + write={ + "presto": "SELECT * FROM x CROSS JOIN UNNEST(y) AS t", + "spark": UnsupportedError, + "databricks": UnsupportedError, + }, + ) + + # UNNEST MAP Object into multiple columns, using single alias + self.validate_all( + "SELECT a, b FROM x CROSS JOIN UNNEST(y) AS t (a, b)", + write={ + "presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y) AS t(a, b)", + "spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a, b", + "hive": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a, b", + }, + ) + + # Unnest multiple Expression into respective mapped alias + self.validate_all( + "SELECT numbers, animals, n, a FROM (SELECT ARRAY(2, 5) AS numbers, ARRAY('dog', 'cat', 'bird') AS animals UNION ALL SELECT ARRAY(7, 8, 9), ARRAY('cow', 'pig')) AS x CROSS JOIN UNNEST(numbers, animals) AS t(n, a)", + write={ + "presto": "SELECT numbers, animals, n, a FROM (SELECT ARRAY[2, 5] AS numbers, ARRAY['dog', 'cat', 'bird'] AS animals UNION ALL SELECT ARRAY[7, 8, 9], ARRAY['cow', 'pig']) AS x CROSS JOIN UNNEST(numbers, animals) AS t(n, a)", + "spark": "SELECT numbers, animals, n, a FROM (SELECT ARRAY(2, 5) AS numbers, ARRAY('dog', 'cat', 'bird') AS animals UNION ALL SELECT ARRAY(7, 8, 9), ARRAY('cow', 'pig')) AS x LATERAL VIEW INLINE(ARRAYS_ZIP(numbers, animals)) t AS n, a", + "hive": UnsupportedError, + }, + ) + + # Unnest column to more then 2 alias (STRUCT) + self.validate_all( + "SELECT a, b, c, d, e FROM x CROSS JOIN UNNEST(y) AS t(a, b, c, d)", + write={ + "presto": "SELECT a, b, c, d, e FROM x CROSS JOIN UNNEST(y) AS t(a, b, c, d)", + "spark": UnsupportedError, + "hive": UnsupportedError, + }, + ) + + def test_multiple_chained_unnest(self): + self.validate_all( + "SELECT * FROM x CROSS JOIN UNNEST(a) AS j(lista) CROSS JOIN UNNEST(b) AS k(listb) CROSS JOIN UNNEST(c) AS l(listc)", + write={ + "presto": "SELECT * FROM x CROSS JOIN UNNEST(a) AS j(lista) CROSS JOIN UNNEST(b) AS k(listb) CROSS JOIN UNNEST(c) AS l(listc)", + "spark": "SELECT * FROM x LATERAL VIEW EXPLODE(a) j AS lista LATERAL VIEW EXPLODE(b) k AS listb LATERAL VIEW EXPLODE(c) l AS listc", + "hive": "SELECT * FROM x LATERAL VIEW EXPLODE(a) j AS lista LATERAL VIEW EXPLODE(b) k AS listb LATERAL VIEW EXPLODE(c) l AS listc", + }, + ) + def test_lateral_subquery(self): self.validate_identity( "SELECT art FROM tbl1 INNER JOIN LATERAL (SELECT art FROM tbl2) AS tbl2 ON tbl1.art = tbl2.art" @@ -1761,16 +1811,68 @@ class TestDialect(Validator): ) self.validate_all( "LEVENSHTEIN(col1, col2)", - write={ + read={ "bigquery": "EDIT_DISTANCE(col1, col2)", - "duckdb": "LEVENSHTEIN(col1, col2)", + "clickhouse": "editDistance(col1, col2)", "drill": "LEVENSHTEIN_DISTANCE(col1, col2)", + "duckdb": "LEVENSHTEIN(col1, col2)", + "hive": "LEVENSHTEIN(col1, col2)", + "spark": "LEVENSHTEIN(col1, col2)", + "postgres": "LEVENSHTEIN(col1, col2)", "presto": "LEVENSHTEIN_DISTANCE(col1, col2)", + "snowflake": "EDITDISTANCE(col1, col2)", + "sqlite": "EDITDIST3(col1, col2)", + "trino": "LEVENSHTEIN_DISTANCE(col1, col2)", + }, + write={ + "bigquery": "EDIT_DISTANCE(col1, col2)", + "clickhouse": "editDistance(col1, col2)", + "drill": "LEVENSHTEIN_DISTANCE(col1, col2)", + "duckdb": "LEVENSHTEIN(col1, col2)", "hive": "LEVENSHTEIN(col1, col2)", "spark": "LEVENSHTEIN(col1, col2)", + "postgres": "LEVENSHTEIN(col1, col2)", + "presto": "LEVENSHTEIN_DISTANCE(col1, col2)", + "snowflake": "EDITDISTANCE(col1, col2)", + "sqlite": "EDITDIST3(col1, col2)", + "trino": "LEVENSHTEIN_DISTANCE(col1, col2)", + }, + ) + + self.validate_all( + "LEVENSHTEIN(col1, col2, 1, 2, 3)", + write={ + "bigquery": UnsupportedError, + "clickhouse": UnsupportedError, + "drill": UnsupportedError, + "duckdb": UnsupportedError, + "hive": UnsupportedError, + "spark": UnsupportedError, + "postgres": "LEVENSHTEIN(col1, col2, 1, 2, 3)", + "presto": UnsupportedError, + "snowflake": UnsupportedError, + "sqlite": UnsupportedError, + "trino": UnsupportedError, }, ) self.validate_all( + "LEVENSHTEIN(col1, col2, 1, 2, 3, 4)", + write={ + "bigquery": UnsupportedError, + "clickhouse": UnsupportedError, + "drill": UnsupportedError, + "duckdb": UnsupportedError, + "hive": UnsupportedError, + "spark": UnsupportedError, + "postgres": "LEVENSHTEIN_LESS_EQUAL(col1, col2, 1, 2, 3, 4)", + "presto": UnsupportedError, + "snowflake": UnsupportedError, + "sqlite": UnsupportedError, + "trino": UnsupportedError, + }, + ) + + self.validate_all( "LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))", write={ "bigquery": "EDIT_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))", @@ -3007,7 +3109,7 @@ FROM subquery2""", "databricks": f"MEDIAN(x){suffix}", "redshift": f"MEDIAN(x){suffix}", "oracle": f"MEDIAN(x){suffix}", - "clickhouse": f"MEDIAN(x){suffix}", + "clickhouse": f"median(x){suffix}", "postgres": f"PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}", }, ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index b59ac9f..3d4fe9c 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -852,6 +852,7 @@ class TestDuckDB(Validator): "clickhouse": "DATE_TRUNC('DAY', x)", }, ) + self.validate_identity("EDITDIST3(col1, col2)", "LEVENSHTEIN(col1, col2)") self.validate_identity("SELECT LENGTH(foo)") self.validate_identity("SELECT ARRAY[1, 2, 3]", "SELECT [1, 2, 3]") @@ -1153,6 +1154,7 @@ class TestDuckDB(Validator): self.validate_identity("CAST(x AS BINARY)", "CAST(x AS BLOB)") self.validate_identity("CAST(x AS VARBINARY)", "CAST(x AS BLOB)") self.validate_identity("CAST(x AS LOGICAL)", "CAST(x AS BOOLEAN)") + self.validate_identity("""CAST({'i': 1, 's': 'foo'} AS STRUCT("s" TEXT, "i" INT))""") self.validate_identity( "CAST(ROW(1, ROW(1)) AS STRUCT(number BIGINT, row STRUCT(number BIGINT)))" ) @@ -1162,11 +1164,11 @@ class TestDuckDB(Validator): ) self.validate_identity( "CAST([[STRUCT_PACK(a := 1)]] AS STRUCT(a BIGINT)[][])", - "CAST([[ROW(1)]] AS STRUCT(a BIGINT)[][])", + "CAST([[{'a': 1}]] AS STRUCT(a BIGINT)[][])", ) self.validate_identity( "CAST([STRUCT_PACK(a := 1)] AS STRUCT(a BIGINT)[])", - "CAST([ROW(1)] AS STRUCT(a BIGINT)[])", + "CAST([{'a': 1}] AS STRUCT(a BIGINT)[])", ) self.validate_identity( "STRUCT_PACK(a := 'b')::json", @@ -1174,7 +1176,7 @@ class TestDuckDB(Validator): ) self.validate_identity( "STRUCT_PACK(a := 'b')::STRUCT(a TEXT)", - "CAST(ROW('b') AS STRUCT(a TEXT))", + "CAST({'a': 'b'} AS STRUCT(a TEXT))", ) self.validate_all( diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 4b54cd0..ffe08c6 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -51,7 +51,6 @@ class TestPostgres(Validator): self.validate_identity("x$") self.validate_identity("SELECT ARRAY[1, 2, 3]") self.validate_identity("SELECT ARRAY(SELECT 1)") - self.validate_identity("SELECT ARRAY_LENGTH(ARRAY[1, 2, 3], 1)") self.validate_identity("STRING_AGG(x, y)") self.validate_identity("STRING_AGG(x, ',' ORDER BY y)") self.validate_identity("STRING_AGG(x, ',' ORDER BY y DESC)") @@ -683,6 +682,11 @@ class TestPostgres(Validator): """SELECT TRIM(TRAILING ' XXX ' COLLATE "de_DE")""", """SELECT RTRIM(' XXX ' COLLATE "de_DE")""", ) + self.validate_identity("LEVENSHTEIN(col1, col2)") + self.validate_identity("LEVENSHTEIN_LESS_EQUAL(col1, col2, 1)") + self.validate_identity("LEVENSHTEIN(col1, col2, 1, 2, 3)") + self.validate_identity("LEVENSHTEIN_LESS_EQUAL(col1, col2, 1, 2, 3, 4)") + self.validate_all( """'{"a":1,"b":2}'::json->'b'""", write={ @@ -1237,3 +1241,49 @@ CROSS JOIN JSON_ARRAY_ELEMENTS(CAST(JSON_EXTRACT_PATH(tbox, 'boxes') AS JSON)) A self.validate_identity( """SELECT * FROM table1, ROWS FROM (FUNC1(col1) AS alias1("col1" TEXT)) WITH ORDINALITY AS alias3("col3" INT, "col4" TEXT)""" ) + + def test_array_length(self): + self.validate_identity("SELECT ARRAY_LENGTH(ARRAY[1, 2, 3], 1)") + + self.validate_all( + "ARRAY_LENGTH(arr, 1)", + read={ + "bigquery": "ARRAY_LENGTH(arr)", + "duckdb": "ARRAY_LENGTH(arr)", + "presto": "CARDINALITY(arr)", + "drill": "REPEATED_COUNT(arr)", + "teradata": "CARDINALITY(arr)", + "hive": "SIZE(arr)", + "spark2": "SIZE(arr)", + "spark": "SIZE(arr)", + "databricks": "SIZE(arr)", + }, + write={ + "duckdb": "ARRAY_LENGTH(arr, 1)", + "presto": "CARDINALITY(arr)", + "teradata": "CARDINALITY(arr)", + "bigquery": "ARRAY_LENGTH(arr)", + "drill": "REPEATED_COUNT(arr)", + "clickhouse": "LENGTH(arr)", + "hive": "SIZE(arr)", + "spark2": "SIZE(arr)", + "spark": "SIZE(arr)", + "databricks": "SIZE(arr)", + }, + ) + + self.validate_all( + "ARRAY_LENGTH(arr, foo)", + write={ + "duckdb": "ARRAY_LENGTH(arr, foo)", + "hive": UnsupportedError, + "spark2": UnsupportedError, + "spark": UnsupportedError, + "databricks": UnsupportedError, + "presto": UnsupportedError, + "teradata": UnsupportedError, + "bigquery": UnsupportedError, + "drill": UnsupportedError, + "clickhouse": UnsupportedError, + }, + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 8357642..e2db661 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -946,6 +946,16 @@ WHERE }, ) + self.validate_identity("EDITDISTANCE(col1, col2)") + self.validate_all( + "EDITDISTANCE(col1, col2, 3)", + write={ + "bigquery": "EDIT_DISTANCE(col1, col2, max_distance => 3)", + "postgres": "LEVENSHTEIN_LESS_EQUAL(col1, col2, 3)", + "snowflake": "EDITDISTANCE(col1, col2, 3)", + }, + ) + def test_null_treatment(self): self.validate_all( r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1", @@ -1788,7 +1798,6 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattene self.validate_all( "REGEXP_SUBSTR(subject, pattern, 1, 1, 'c', group)", read={ - "bigquery": "REGEXP_SUBSTR(subject, pattern, 1, 1, 'c', group)", "duckdb": "REGEXP_EXTRACT(subject, pattern, group)", "hive": "REGEXP_EXTRACT(subject, pattern, group)", "presto": "REGEXP_EXTRACT(subject, pattern, group)", @@ -1797,6 +1806,11 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattene }, ) + self.validate_identity( + "REGEXP_SUBSTR_ALL(subject, pattern)", + "REGEXP_EXTRACT_ALL(subject, pattern)", + ) + @mock.patch("sqlglot.generator.logger") def test_regexp_replace(self, logger): self.validate_all( diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 486bf79..1aa5c21 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -754,6 +754,17 @@ TBLPROPERTIES ( }, ) + self.validate_all( + "SELECT TIMESTAMPDIFF(MONTH, foo, bar)", + read={ + "databricks": "SELECT TIMESTAMPDIFF(MONTH, foo, bar)", + }, + write={ + "spark": "SELECT TIMESTAMPDIFF(MONTH, foo, bar)", + "databricks": "SELECT TIMESTAMPDIFF(MONTH, foo, bar)", + }, + ) + def test_bool_or(self): self.validate_all( "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a", |