diff options
Diffstat (limited to '')
24 files changed, 574 insertions, 82 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 1f5f902..8d172ea 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -122,6 +122,14 @@ class TestBigQuery(Validator): """SELECT JSON '"foo"' AS json_data""", """SELECT PARSE_JSON('"foo"') AS json_data""", ) + self.validate_identity( + "CREATE OR REPLACE TABLE `a.b.c` CLONE `a.b.d`", + "CREATE OR REPLACE TABLE a.b.c CLONE a.b.d", + ) + self.validate_identity( + "SELECT * FROM UNNEST(x) WITH OFFSET EXCEPT DISTINCT SELECT * FROM UNNEST(y) WITH OFFSET", + "SELECT * FROM UNNEST(x) WITH OFFSET AS offset EXCEPT DISTINCT SELECT * FROM UNNEST(y) WITH OFFSET AS offset", + ) self.validate_all("SELECT SPLIT(foo)", write={"bigquery": "SELECT SPLIT(foo, ',')"}) self.validate_all("SELECT 1 AS hash", write={"bigquery": "SELECT 1 AS `hash`"}) @@ -131,6 +139,35 @@ class TestBigQuery(Validator): self.validate_all("x <> ''''''", write={"bigquery": "x <> ''"}) self.validate_all("CAST(x AS DATETIME)", read={"": "x::timestamp"}) self.validate_all( + "SELECT '\\n'", + read={ + "bigquery": "SELECT '''\n'''", + }, + write={ + "bigquery": "SELECT '\\n'", + "postgres": "SELECT '\n'", + }, + ) + self.validate_all( + "TRIM(item, '*')", + read={ + "snowflake": "TRIM(item, '*')", + "spark": "TRIM('*', item)", + }, + write={ + "bigquery": "TRIM(item, '*')", + "snowflake": "TRIM(item, '*')", + "spark": "TRIM('*' FROM item)", + }, + ) + self.validate_all( + "CREATE OR REPLACE TABLE `a.b.c` COPY `a.b.d`", + write={ + "bigquery": "CREATE OR REPLACE TABLE a.b.c COPY a.b.d", + "snowflake": "CREATE OR REPLACE TABLE a.b.c CLONE a.b.d", + }, + ) + self.validate_all( "SELECT DATETIME_DIFF('2023-01-01T00:00:00', '2023-01-01T05:00:00', MILLISECOND)", write={ "bigquery": "SELECT DATETIME_DIFF('2023-01-01T00:00:00', '2023-01-01T05:00:00', MILLISECOND)", @@ -608,6 +645,9 @@ class TestBigQuery(Validator): "postgres": "CURRENT_DATE AT TIME ZONE 'UTC'", }, ) + self.validate_identity( + "SELECT * FROM test QUALIFY a IS DISTINCT FROM b WINDOW c AS (PARTITION BY d)" + ) self.validate_all( "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10", write={ diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 40a270e..948c00e 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -24,6 +24,9 @@ class TestClickhouse(Validator): self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)") self.assertIsNone(expr._meta) + self.validate_identity("SELECT * FROM (SELECT a FROM b SAMPLE 0.01)") + self.validate_identity("SELECT * FROM (SELECT a FROM b SAMPLE 1 / 10 OFFSET 1 / 2)") + self.validate_identity("SELECT sum(foo * bar) FROM bla SAMPLE 10000000") self.validate_identity("CAST(x AS Nested(ID UInt32, Serial UInt32, EventTime DATETIME))") self.validate_identity("CAST(x AS Enum('hello' = 1, 'world' = 2))") self.validate_identity("CAST(x AS Enum('hello', 'world'))") @@ -83,6 +86,16 @@ class TestClickhouse(Validator): ) self.validate_all( + "SELECT '\\0'", + read={ + "mysql": "SELECT '\0'", + }, + write={ + "clickhouse": "SELECT '\\0'", + "mysql": "SELECT '\0'", + }, + ) + self.validate_all( "DATE_ADD('day', 1, x)", read={ "clickhouse": "dateAdd(day, 1, x)", @@ -224,6 +237,33 @@ class TestClickhouse(Validator): self.validate_identity( "SELECT s, arr_external FROM arrays_test ARRAY JOIN [1, 2, 3] AS arr_external" ) + self.validate_all( + "SELECT quantile(0.5)(a)", + read={"duckdb": "SELECT quantile(a, 0.5)"}, + write={"clickhouse": "SELECT quantile(0.5)(a)"}, + ) + self.validate_all( + "SELECT quantiles(0.5, 0.4)(a)", + read={"duckdb": "SELECT quantile(a, [0.5, 0.4])"}, + write={"clickhouse": "SELECT quantiles(0.5, 0.4)(a)"}, + ) + self.validate_all( + "SELECT quantiles(0.5)(a)", + read={"duckdb": "SELECT quantile(a, [0.5])"}, + write={"clickhouse": "SELECT quantiles(0.5)(a)"}, + ) + + self.validate_identity("SELECT isNaN(x)") + self.validate_all( + "SELECT IS_NAN(x), ISNAN(x)", + write={"clickhouse": "SELECT isNaN(x), isNaN(x)"}, + ) + + self.validate_identity("SELECT startsWith('a', 'b')") + self.validate_all( + "SELECT STARTS_WITH('a', 'b'), STARTSWITH('a', 'b')", + write={"clickhouse": "SELECT startsWith('a', 'b'), startsWith('a', 'b')"}, + ) def test_cte(self): self.validate_identity("WITH 'x' AS foo SELECT foo") @@ -305,6 +345,9 @@ class TestClickhouse(Validator): def test_ddl(self): self.validate_identity( + 'CREATE TABLE data5 ("x" UInt32, "y" UInt32) ENGINE=MergeTree ORDER BY (round(y / 1000000000), cityHash64(x)) SAMPLE BY cityHash64(x)' + ) + self.validate_identity( "CREATE TABLE foo (x UInt32) TTL time_column + INTERVAL '1' MONTH DELETE WHERE column = 'value'" ) diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 3df968b..7c03c83 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -32,6 +32,13 @@ class TestDatabricks(Validator): "CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))", write={ "databricks": "CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(TO_DATE(y))))", + "tsql": "CREATE TABLE foo (x AS YEAR(CAST(y AS DATE)))", + }, + ) + self.validate_all( + "CREATE TABLE t1 AS (SELECT c FROM t2)", + read={ + "teradata": "CREATE TABLE t1 AS (SELECT c FROM t2) WITH DATA", }, ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 3e0ffd5..91eba17 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -5,6 +5,7 @@ from sqlglot import ( Dialects, ErrorLevel, ParseError, + TokenError, UnsupportedError, parse_one, ) @@ -308,6 +309,44 @@ class TestDialect(Validator): read={"postgres": "INET '127.0.0.1/32'"}, ) + def test_heredoc_strings(self): + for dialect in ("clickhouse", "postgres", "redshift"): + # Invalid matching tag + with self.assertRaises(TokenError): + parse_one("SELECT $tag1$invalid heredoc string$tag2$", dialect=dialect) + + # Unmatched tag + with self.assertRaises(TokenError): + parse_one("SELECT $tag1$invalid heredoc string", dialect=dialect) + + # Without tag + self.validate_all( + "SELECT 'this is a heredoc string'", + read={ + dialect: "SELECT $$this is a heredoc string$$", + }, + ) + self.validate_all( + "SELECT ''", + read={ + dialect: "SELECT $$$$", + }, + ) + + # With tag + self.validate_all( + "SELECT 'this is also a heredoc string'", + read={ + dialect: "SELECT $foo$this is also a heredoc string$foo$", + }, + ) + self.validate_all( + "SELECT ''", + read={ + dialect: "SELECT $foo$$foo$", + }, + ) + def test_decode(self): self.validate_identity("DECODE(bin, charset)") @@ -568,6 +607,7 @@ class TestDialect(Validator): "presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)", "snowflake": "CAST(x AS DATE)", "doris": "TO_DATE(x)", + "mysql": "DATE(x)", }, ) self.validate_all( @@ -648,9 +688,7 @@ class TestDialect(Validator): self.validate_all( "DATE_ADD(x, 1, 'DAY')", read={ - "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", "snowflake": "DATEADD('DAY', 1, x)", - "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", }, write={ "bigquery": "DATE_ADD(x, INTERVAL 1 DAY)", @@ -842,6 +880,7 @@ class TestDialect(Validator): "hive": "DATE_ADD('2021-02-01', 1)", "presto": "DATE_ADD('DAY', 1, CAST(CAST('2021-02-01' AS TIMESTAMP) AS DATE))", "spark": "DATE_ADD('2021-02-01', 1)", + "mysql": "DATE_ADD('2021-02-01', INTERVAL 1 DAY)", }, ) self.validate_all( @@ -897,10 +936,7 @@ class TestDialect(Validator): "bigquery", "drill", "duckdb", - "mysql", "presto", - "starrocks", - "doris", ) }, write={ @@ -913,8 +949,25 @@ class TestDialect(Validator): "presto", "hive", "spark", + ) + }, + ) + self.validate_all( + f"{unit}(TS_OR_DS_TO_DATE(x))", + read={ + dialect: f"{unit}(x)" + for dialect in ( + "mysql", + "doris", "starrocks", + ) + }, + write={ + dialect: f"{unit}(x)" + for dialect in ( + "mysql", "doris", + "starrocks", ) }, ) @@ -1790,3 +1843,17 @@ SELECT with self.assertRaises(ParseError): parse_one("CAST(x AS some_udt)", read="bigquery") + + def test_qualify(self): + self.validate_all( + "SELECT * FROM t QUALIFY COUNT(*) OVER () > 1", + write={ + "duckdb": "SELECT * FROM t QUALIFY COUNT(*) OVER () > 1", + "snowflake": "SELECT * FROM t QUALIFY COUNT(*) OVER () > 1", + "clickhouse": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1", + "mysql": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1", + "oracle": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) _t WHERE _w > 1", + "postgres": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1", + "tsql": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1", + }, + ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index dbf0a87..240f6f9 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -10,6 +10,10 @@ class TestDuckDB(Validator): parse_one("select * from t limit (select 5)").sql(dialect="duckdb"), exp.select("*").from_("t").limit(exp.select("5").subquery()).sql(dialect="duckdb"), ) + self.assertEqual( + parse_one("select * from t offset (select 5)").sql(dialect="duckdb"), + exp.select("*").from_("t").offset(exp.select("5").subquery()).sql(dialect="duckdb"), + ) for struct_value in ("{'a': 1}", "struct_pack(a := 1)"): self.validate_all(struct_value, write={"presto": UnsupportedError}) @@ -287,6 +291,8 @@ class TestDuckDB(Validator): "duckdb": "STRUCT_EXTRACT(x, 'abc')", "presto": "x.abc", "hive": "x.abc", + "postgres": "x.abc", + "redshift": "x.abc", "spark": "x.abc", }, ) @@ -446,6 +452,7 @@ class TestDuckDB(Validator): write={ "duckdb": "SELECT QUANTILE_CONT(x, q) FROM t", "postgres": "SELECT PERCENTILE_CONT(q) WITHIN GROUP (ORDER BY x) FROM t", + "snowflake": "SELECT PERCENTILE_CONT(q) WITHIN GROUP (ORDER BY x) FROM t", }, ) self.validate_all( @@ -453,6 +460,7 @@ class TestDuckDB(Validator): write={ "duckdb": "SELECT QUANTILE_DISC(x, q) FROM t", "postgres": "SELECT PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x) FROM t", + "snowflake": "SELECT PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x) FROM t", }, ) self.validate_all( @@ -460,6 +468,7 @@ class TestDuckDB(Validator): write={ "duckdb": "SELECT QUANTILE_CONT(x, 0.5) FROM t", "postgres": "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t", + "snowflake": "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t", }, ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 20f872c..11f921c 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -12,6 +12,8 @@ class TestMySQL(Validator): self.validate_identity(f"CREATE TABLE t (id {t} UNSIGNED)") self.validate_identity(f"CREATE TABLE t (id {t}(10) UNSIGNED)") + self.validate_identity("CREATE TABLE t (id DECIMAL(20, 4) UNSIGNED)") + self.validate_all( "CREATE TABLE t (id INT UNSIGNED)", write={ @@ -205,6 +207,9 @@ class TestMySQL(Validator): ) self.validate_identity("INTERVAL '1' YEAR") self.validate_identity("DATE_ADD(x, INTERVAL 1 YEAR)") + self.validate_identity("CHAR(0)") + self.validate_identity("CHAR(77, 121, 83, 81, '76')") + self.validate_identity("CHAR(77, 77.3, '77.3' USING utf8mb4)") def test_types(self): self.validate_identity("CAST(x AS MEDIUMINT) + CAST(y AS YEAR(4))") @@ -244,6 +249,13 @@ class TestMySQL(Validator): self.validate_identity( "SELECT WEEK_OF_YEAR('2023-01-01')", "SELECT WEEKOFYEAR('2023-01-01')" ) + self.validate_all( + "CHAR(10)", + write={ + "mysql": "CHAR(10)", + "presto": "CHR(10)", + }, + ) def test_escape(self): self.validate_identity("""'"abc"'""") @@ -496,6 +508,56 @@ class TestMySQL(Validator): self.validate_identity("FROM_UNIXTIME(a, b)") self.validate_identity("FROM_UNIXTIME(a, b, c)") self.validate_identity("TIME_STR_TO_UNIX(x)", "UNIX_TIMESTAMP(x)") + self.validate_all( + "SELECT TO_DAYS(x)", + write={ + "mysql": "SELECT (DATEDIFF(x, '0000-01-01') + 1)", + "presto": "SELECT (DATE_DIFF('DAY', CAST(CAST('0000-01-01' AS TIMESTAMP) AS DATE), CAST(CAST(x AS TIMESTAMP) AS DATE)) + 1)", + }, + ) + self.validate_all( + "SELECT DATEDIFF(x, y)", + write={"mysql": "SELECT DATEDIFF(x, y)", "presto": "SELECT DATE_DIFF('day', y, x)"}, + ) + self.validate_all( + "DAYOFYEAR(x)", + write={ + "mysql": "DAYOFYEAR(x)", + "": "DAY_OF_YEAR(TS_OR_DS_TO_DATE(x))", + }, + ) + self.validate_all( + "DAYOFMONTH(x)", + write={"mysql": "DAYOFMONTH(x)", "": "DAY_OF_MONTH(TS_OR_DS_TO_DATE(x))"}, + ) + self.validate_all( + "DAYOFWEEK(x)", + write={"mysql": "DAYOFWEEK(x)", "": "DAY_OF_WEEK(TS_OR_DS_TO_DATE(x))"}, + ) + self.validate_all( + "WEEKOFYEAR(x)", + write={"mysql": "WEEKOFYEAR(x)", "": "WEEK_OF_YEAR(TS_OR_DS_TO_DATE(x))"}, + ) + self.validate_all( + "DAY(x)", + write={"mysql": "DAY(x)", "": "DAY(TS_OR_DS_TO_DATE(x))"}, + ) + self.validate_all( + "WEEK(x)", + write={"mysql": "WEEK(x)", "": "WEEK(TS_OR_DS_TO_DATE(x))"}, + ) + self.validate_all( + "YEAR(x)", + write={"mysql": "YEAR(x)", "": "YEAR(TS_OR_DS_TO_DATE(x))"}, + ) + self.validate_all( + "DATE(x)", + read={"": "TS_OR_DS_TO_DATE(x)"}, + ) + self.validate_all( + "STR_TO_DATE(x, '%M')", + read={"": "TS_OR_DS_TO_DATE(x, '%B')"}, + ) def test_mysql(self): self.validate_all( @@ -896,7 +958,7 @@ COMMENT='客户账户表'""" self.validate_all( "MONTHNAME(x)", write={ - "": "TIME_TO_STR(x, '%B')", + "": "TIME_TO_STR(TS_OR_DS_TO_DATE(x), '%B')", "mysql": "DATE_FORMAT(x, '%M')", }, ) diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 675ee8a..5572ec1 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -22,8 +22,6 @@ class TestOracle(Validator): self.validate_identity("SELECT * FROM t FOR UPDATE OF s.t.c, s.t.v SKIP LOCKED") self.validate_identity("SELECT STANDARD_HASH('hello')") self.validate_identity("SELECT STANDARD_HASH('hello', 'MD5')") - self.validate_identity("SELECT CAST(NULL AS VARCHAR2(2328 CHAR)) AS COL1") - self.validate_identity("SELECT CAST(NULL AS VARCHAR2(2328 BYTE)) AS COL1") self.validate_identity("SELECT * FROM table_name@dblink_name.database_link_domain") self.validate_identity("SELECT * FROM table_name SAMPLE (25) s") self.validate_identity("SELECT * FROM V$SESSION") @@ -61,6 +59,20 @@ class TestOracle(Validator): ) self.validate_all( + "SELECT CAST(NULL AS VARCHAR2(2328 CHAR)) AS COL1", + write={ + "oracle": "SELECT CAST(NULL AS VARCHAR2(2328 CHAR)) AS COL1", + "spark": "SELECT CAST(NULL AS VARCHAR(2328)) AS COL1", + }, + ) + self.validate_all( + "SELECT CAST(NULL AS VARCHAR2(2328 BYTE)) AS COL1", + write={ + "oracle": "SELECT CAST(NULL AS VARCHAR2(2328 BYTE)) AS COL1", + "spark": "SELECT CAST(NULL AS VARCHAR(2328)) AS COL1", + }, + ) + self.validate_all( "NVL(NULL, 1)", write={ "": "COALESCE(NULL, 1)", diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 6a3df47..0ddc106 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -10,6 +10,9 @@ class TestPostgres(Validator): def test_ddl(self): self.validate_identity( + "CREATE INDEX foo ON bar.baz USING btree(col1 varchar_pattern_ops ASC, col2)" + ) + self.validate_identity( "CREATE TABLE test (x TIMESTAMP WITHOUT TIME ZONE[][])", "CREATE TABLE test (x TIMESTAMP[][])", ) @@ -149,15 +152,27 @@ class TestPostgres(Validator): ) def test_postgres(self): - expr = parse_one("SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)") + expr = parse_one( + "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)", read="postgres" + ) unnest = expr.args["joins"][0].this.this unnest.assert_is(exp.Unnest) alter_table_only = """ALTER TABLE ONLY "Album" ADD CONSTRAINT "FK_AlbumArtistId" FOREIGN KEY ("ArtistId") REFERENCES "Artist" ("ArtistId") ON DELETE NO ACTION ON UPDATE NO ACTION""" - expr = parse_one(alter_table_only) + expr = parse_one(alter_table_only, read="postgres") # Checks that user-defined types are parsed into DataType instead of Identifier - parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is(exp.DataType) + parse_one("CREATE TABLE t (a udt)", read="postgres").this.expressions[0].args[ + "kind" + ].assert_is(exp.DataType) + + # Checks that OID is parsed into a DataType (ObjectIdentifier) + self.assertIsInstance( + parse_one("CREATE TABLE public.propertydata (propertyvalue oid)", read="postgres").find( + exp.DataType + ), + exp.ObjectIdentifier, + ) self.assertIsInstance(expr, exp.AlterTable) self.assertEqual(expr.sql(dialect="postgres"), alter_table_only) @@ -192,7 +207,6 @@ class TestPostgres(Validator): self.validate_identity("SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]") self.validate_identity("SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]") self.validate_identity("SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]") - self.validate_identity("$x") self.validate_identity("x$") self.validate_identity("SELECT ARRAY[1, 2, 3]") self.validate_identity("SELECT ARRAY(SELECT 1)") diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index a80013e..8edd31c 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -300,7 +300,6 @@ class TestPresto(Validator): write={ "presto": "DATE_ADD('DAY', 1 * -1, x)", }, - read={"mysql": "DATE_SUB(x, INTERVAL 1 DAY)"}, ) self.validate_all( "NOW()", @@ -503,6 +502,7 @@ class TestPresto(Validator): @mock.patch("sqlglot.helper.logger") def test_presto(self, logger): + self.validate_identity("string_agg(x, ',')", "ARRAY_JOIN(ARRAY_AGG(x), ',')") self.validate_identity( "SELECT * FROM example.testdb.customer_orders FOR VERSION AS OF 8954597067493422955" ) diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index c75654c..ae1b987 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -6,6 +6,11 @@ class TestRedshift(Validator): dialect = "redshift" def test_redshift(self): + self.validate_identity( + "SELECT 'a''b'", + "SELECT 'a\\'b'", + ) + self.validate_all( "x ~* 'pat'", write={ @@ -226,7 +231,6 @@ class TestRedshift(Validator): self.validate_identity("SELECT * FROM #x") self.validate_identity("SELECT INTERVAL '5 day'") self.validate_identity("foo$") - self.validate_identity("$foo") self.validate_identity("CAST('bla' AS SUPER)") self.validate_identity("CREATE TABLE real1 (realcol REAL)") self.validate_identity("CAST('foo' AS HLLSKETCH)") diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index a217394..7c36bea 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -1,6 +1,7 @@ from unittest import mock from sqlglot import UnsupportedError, exp, parse_one +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from tests.dialects.test_dialect import Validator @@ -8,34 +9,6 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): - self.validate_identity( - 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage' - ) - - self.validate_all( - "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d", - read={ - "oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d", - }, - write={ - "oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d", - "snowflake": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d", - }, - ) - self.validate_all( - "SELECT INSERT(a, 0, 0, 'b')", - read={ - "mysql": "SELECT INSERT(a, 0, 0, 'b')", - "snowflake": "SELECT INSERT(a, 0, 0, 'b')", - "tsql": "SELECT STUFF(a, 0, 0, 'b')", - }, - write={ - "mysql": "SELECT INSERT(a, 0, 0, 'b')", - "snowflake": "SELECT INSERT(a, 0, 0, 'b')", - "tsql": "SELECT STUFF(a, 0, 0, 'b')", - }, - ) - self.validate_identity("LISTAGG(data['some_field'], ',')") self.validate_identity("WEEKOFYEAR(tstamp)") self.validate_identity("SELECT SUM(amount) FROM mytable GROUP BY ALL") @@ -54,7 +27,6 @@ class TestSnowflake(Validator): self.validate_identity("$x") # parameter self.validate_identity("a$b") # valid snowflake identifier self.validate_identity("SELECT REGEXP_LIKE(a, b, c)") - self.validate_identity("PUT file:///dir/tmp.csv @%table") self.validate_identity("CREATE TABLE foo (bar FLOAT AUTOINCREMENT START 0 INCREMENT 1)") self.validate_identity("ALTER TABLE IF EXISTS foo SET TAG a = 'a', b = 'b', c = 'c'") self.validate_identity("ALTER TABLE foo UNSET TAG a, b, c") @@ -65,12 +37,16 @@ class TestSnowflake(Validator): self.validate_identity("SELECT CONVERT_TIMEZONE('UTC', 'America/Los_Angeles', col)") self.validate_identity("REGEXP_REPLACE('target', 'pattern', '\n')") self.validate_identity( - 'COPY INTO NEW_TABLE ("foo", "bar") FROM (SELECT $1, $2, $3, $4 FROM @%old_table)' + 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage' ) self.validate_identity( "SELECT state, city, SUM(retail_price * quantity) AS gross_revenue FROM sales GROUP BY ALL" ) self.validate_identity( + "SELECT * FROM foo window", + "SELECT * FROM foo AS window", + ) + self.validate_identity( r"SELECT RLIKE(a, $$regular expression with \ characters: \d{2}-\d{3}-\d{4}$$, 'i') FROM log_source", r"SELECT REGEXP_LIKE(a, 'regular expression with \\ characters: \\d{2}-\\d{3}-\\d{4}', 'i') FROM log_source", ) @@ -88,6 +64,36 @@ class TestSnowflake(Validator): self.validate_all("CAST(x AS CHARACTER VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all("CAST(x AS NCHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all( + "SELECT COLLATE('B', 'und:ci')", + write={ + "bigquery": "SELECT COLLATE('B', 'und:ci')", + "snowflake": "SELECT COLLATE('B', 'und:ci')", + }, + ) + self.validate_all( + "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d", + read={ + "oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d", + }, + write={ + "oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d", + "snowflake": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d", + }, + ) + self.validate_all( + "SELECT INSERT(a, 0, 0, 'b')", + read={ + "mysql": "SELECT INSERT(a, 0, 0, 'b')", + "snowflake": "SELECT INSERT(a, 0, 0, 'b')", + "tsql": "SELECT STUFF(a, 0, 0, 'b')", + }, + write={ + "mysql": "SELECT INSERT(a, 0, 0, 'b')", + "snowflake": "SELECT INSERT(a, 0, 0, 'b')", + "tsql": "SELECT STUFF(a, 0, 0, 'b')", + }, + ) + self.validate_all( "ARRAY_GENERATE_RANGE(0, 3)", write={ "bigquery": "GENERATE_ARRAY(0, 3 - 1)", @@ -513,6 +519,40 @@ class TestSnowflake(Validator): }, ) + def test_staged_files(self): + # Ensure we don't treat staged file paths as identifiers (i.e. they're not normalized) + staged_file = parse_one("SELECT * FROM @foo", read="snowflake") + self.assertEqual( + normalize_identifiers(staged_file, dialect="snowflake").sql(dialect="snowflake"), + staged_file.sql(dialect="snowflake"), + ) + + self.validate_identity("SELECT * FROM @~") + self.validate_identity("SELECT * FROM @~/some/path/to/file.csv") + self.validate_identity("SELECT * FROM @mystage") + self.validate_identity("SELECT * FROM '@mystage'") + self.validate_identity("SELECT * FROM @namespace.mystage/path/to/file.json.gz") + self.validate_identity("SELECT * FROM @namespace.%table_name/path/to/file.json.gz") + self.validate_identity("SELECT * FROM '@external/location' (FILE_FORMAT => 'path.to.csv')") + self.validate_identity("PUT file:///dir/tmp.csv @%table") + self.validate_identity( + 'COPY INTO NEW_TABLE ("foo", "bar") FROM (SELECT $1, $2, $3, $4 FROM @%old_table)' + ) + self.validate_identity( + "SELECT * FROM @foo/bar (FILE_FORMAT => ds_sandbox.test.my_csv_format, PATTERN => 'test') AS bla" + ) + self.validate_identity( + "SELECT t.$1, t.$2 FROM @mystage1 (FILE_FORMAT => 'myformat', PATTERN => '.*data.*[.]csv.gz') AS t" + ) + self.validate_identity( + "SELECT parse_json($1):a.b FROM @mystage2/data1.json.gz", + "SELECT PARSE_JSON($1)['a'].b FROM @mystage2/data1.json.gz", + ) + self.validate_identity( + "SELECT * FROM @mystage t (c1)", + "SELECT * FROM @mystage AS t(c1)", + ) + def test_sample(self): self.validate_identity("SELECT * FROM testtable TABLESAMPLE BERNOULLI (20.3)") self.validate_identity("SELECT * FROM testtable TABLESAMPLE (100)") @@ -660,7 +700,6 @@ class TestSnowflake(Validator): self.validate_identity("CREATE MATERIALIZED VIEW a COMMENT='...' AS SELECT 1 FROM x") self.validate_identity("CREATE DATABASE mytestdb_clone CLONE mytestdb") self.validate_identity("CREATE SCHEMA mytestschema_clone CLONE testschema") - self.validate_identity("CREATE TABLE orders_clone CLONE orders") self.validate_identity("CREATE TABLE IDENTIFIER('foo') (COLUMN1 VARCHAR, COLUMN2 VARCHAR)") self.validate_identity("CREATE TABLE IDENTIFIER($foo) (col1 VARCHAR, col2 VARCHAR)") self.validate_identity( @@ -680,6 +719,16 @@ class TestSnowflake(Validator): ) self.validate_all( + "CREATE TABLE orders_clone CLONE orders", + read={ + "bigquery": "CREATE TABLE orders_clone CLONE orders", + }, + write={ + "bigquery": "CREATE TABLE orders_clone CLONE orders", + "snowflake": "CREATE TABLE orders_clone CLONE orders", + }, + ) + self.validate_all( "CREATE OR REPLACE TRANSIENT TABLE a (id INT)", read={ "postgres": "CREATE OR REPLACE TRANSIENT TABLE a (id INT)", diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 2e43ba5..0148e55 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -8,6 +8,7 @@ class TestSpark(Validator): dialect = "spark" def test_ddl(self): + self.validate_identity("CREATE TEMPORARY VIEW test AS SELECT 1") self.validate_identity("CREATE TABLE foo (col VARCHAR(50))") self.validate_identity("CREATE TABLE foo (col STRUCT<struct_col_a: VARCHAR((50))>)") self.validate_identity("CREATE TABLE foo (col STRING) CLUSTERED BY (col) INTO 10 BUCKETS") diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index f76894d..7d89d06 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -7,6 +7,45 @@ class TestTSQL(Validator): def test_tsql(self): self.validate_all( + "CREATE TABLE #mytemptable (a INTEGER)", + read={ + "duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)", + }, + write={ + "tsql": "CREATE TABLE #mytemptable (a INTEGER)", + "snowflake": "CREATE TEMPORARY TABLE mytemptable (a INT)", + "duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)", + "oracle": "CREATE TEMPORARY TABLE mytemptable (a NUMBER)", + "hive": "CREATE TEMPORARY TABLE mytemptable (a INT)", + "spark2": "CREATE TEMPORARY TABLE mytemptable (a INT) USING PARQUET", + "spark": "CREATE TEMPORARY TABLE mytemptable (a INT) USING PARQUET", + "databricks": "CREATE TEMPORARY TABLE mytemptable (a INT) USING PARQUET", + }, + ) + self.validate_all( + "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))", + write={ + "spark": "CREATE TEMPORARY TABLE mytemp (a INT, b CHAR(2), c TIMESTAMP, d FLOAT) USING PARQUET", + "tsql": "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))", + }, + ) + self.validate_all( + """CREATE TABLE [dbo].[mytable]( + [email] [varchar](255) NOT NULL, + CONSTRAINT [UN_t_mytable] UNIQUE NONCLUSTERED + ( + [email] ASC + ) + )""", + write={ + "hive": "CREATE TABLE `dbo`.`mytable` (`email` VARCHAR(255) NOT NULL)", + "spark2": "CREATE TABLE `dbo`.`mytable` (`email` VARCHAR(255) NOT NULL)", + "spark": "CREATE TABLE `dbo`.`mytable` (`email` VARCHAR(255) NOT NULL)", + "databricks": "CREATE TABLE `dbo`.`mytable` (`email` VARCHAR(255) NOT NULL)", + }, + ) + + self.validate_all( "CREATE TABLE x ( A INTEGER NOT NULL, B INTEGER NULL )", write={ "tsql": "CREATE TABLE x (A INTEGER NOT NULL, B INTEGER NULL)", @@ -492,6 +531,10 @@ class TestTSQL(Validator): ) def test_ddl(self): + self.validate_identity( + "CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < CURRENT_TIMESTAMP - 7 END", + "CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < GETDATE() - 7 END", + ) self.validate_all( "CREATE TABLE tbl (id INTEGER IDENTITY PRIMARY KEY)", read={ @@ -505,6 +548,9 @@ class TestTSQL(Validator): "postgres": "CREATE TABLE tbl (id INT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10) PRIMARY KEY)", "tsql": "CREATE TABLE tbl (id INTEGER NOT NULL IDENTITY(10, 1) PRIMARY KEY)", }, + write={ + "databricks": "CREATE TABLE tbl (id BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1) PRIMARY KEY)", + }, ) self.validate_all( "SELECT * INTO foo.bar.baz FROM (SELECT * FROM a.b.c) AS temp", @@ -561,22 +607,10 @@ class TestTSQL(Validator): self.validate_all( "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))", write={ - "spark": "CREATE TEMPORARY TABLE mytemp (a INT, b CHAR(2), c TIMESTAMP, d FLOAT)", + "spark": "CREATE TEMPORARY TABLE mytemp (a INT, b CHAR(2), c TIMESTAMP, d FLOAT) USING PARQUET", "tsql": "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))", }, ) - self.validate_all( - "CREATE TABLE #mytemptable (a INTEGER)", - read={ - "duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)", - }, - write={ - "tsql": "CREATE TABLE #mytemptable (a INTEGER)", - "snowflake": "CREATE TEMPORARY TABLE mytemptable (a INT)", - "duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)", - "oracle": "CREATE TEMPORARY TABLE mytemptable (a NUMBER)", - }, - ) def test_insert_cte(self): self.validate_all( diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 17506e4..2738707 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -771,8 +771,8 @@ ALTER TABLE integers DROP COLUMN k ALTER TABLE integers DROP PRIMARY KEY ALTER TABLE integers DROP COLUMN IF EXISTS k ALTER TABLE integers DROP COLUMN k CASCADE -ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR -ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR USING CONCAT(i, '_', j) +ALTER TABLE integers ALTER COLUMN i SET DATA TYPE VARCHAR +ALTER TABLE integers ALTER COLUMN i SET DATA TYPE VARCHAR USING CONCAT(i, '_', j) ALTER TABLE integers ALTER COLUMN i SET DEFAULT 10 ALTER TABLE integers ALTER COLUMN i DROP DEFAULT ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B @@ -864,3 +864,5 @@ SELECT x FROM y ORDER BY x ASC KILL '123' KILL CONNECTION 123 KILL QUERY '123' +CHR(97) +SELECT * FROM UNNEST(x) WITH ORDINALITY UNION ALL SELECT * FROM UNNEST(y) WITH ORDINALITY diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql index e27b2d3..2ba762d 100644 --- a/tests/fixtures/optimizer/canonicalize.sql +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -29,6 +29,12 @@ SELECT "x"."a" AS "a" FROM "x" AS "x" GROUP BY "x"."a" HAVING SUM("x"."b") <> 0 SELECT a FROM x WHERE 1; SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE 1 <> 0; +SELECT a FROM x WHERE COALESCE(0, 1); +SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE COALESCE(0 <> 0, 1 <> 0); + +SELECT a FROM x WHERE CASE WHEN COALESCE(b, 1) THEN 1 ELSE 0 END; +SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE CASE WHEN COALESCE("x"."b" <> 0, 1 <> 0) THEN 1 ELSE 0 END <> 0; + -------------------------------------- -- Replace date functions -------------------------------------- @@ -40,3 +46,9 @@ CAST('2023-01-01' AS TIMESTAMP); TIMESTAMP('2023-01-01', '12:00:00'); TIMESTAMP('2023-01-01', '12:00:00'); + +DATE_ADD(CAST("x" AS DATE), 1, 'YEAR'); +DATE_ADD(CAST("x" AS DATE), 1, 'YEAR'); + +DATE_ADD('2023-01-01', 1, 'YEAR'); +DATE_ADD(CAST('2023-01-01' AS DATE), 1, 'YEAR'); diff --git a/tests/fixtures/optimizer/normalize_identifiers.sql b/tests/fixtures/optimizer/normalize_identifiers.sql index 2ab4778..4cb7dd1 100644 --- a/tests/fixtures/optimizer/normalize_identifiers.sql +++ b/tests/fixtures/optimizer/normalize_identifiers.sql @@ -62,3 +62,11 @@ SELECT a AS a FROM x UNION SELECT a AS a FROM x; (SELECT A AS A FROM X); (SELECT a AS a FROM x); + +# dialect: snowflake +SELECT a /* sqlglot.meta case_sensitive */, b FROM table /* sqlglot.meta case_sensitive */; +SELECT a /* sqlglot.meta case_sensitive */, B FROM table /* sqlglot.meta case_sensitive */; + +# dialect: redshift +SELECT COALESCE(json_val.a /* sqlglot.meta case_sensitive */, json_val.A /* sqlglot.meta case_sensitive */) FROM table; +SELECT COALESCE(json_val.a /* sqlglot.meta case_sensitive */, json_val.A /* sqlglot.meta case_sensitive */) FROM table; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index e59f14d..4cc62c9 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -1023,3 +1023,25 @@ SELECT FROM "table1" AS "table1" LEFT JOIN "alias3" ON "table1"."cid" = "alias3"."cid"; + +# title: CTE with EXPLODE cannot be merged +# dialect: spark +# execute: false +SELECT Name, + FruitStruct.`$id`, + FruitStruct.value + FROM + (SELECT Name, + explode(Fruits) as FruitStruct + FROM fruits_table); +WITH `_q_0` AS ( + SELECT + `fruits_table`.`name` AS `name`, + EXPLODE(`fruits_table`.`fruits`) AS `fruitstruct` + FROM `fruits_table` AS `fruits_table` +) +SELECT + `_q_0`.`name` AS `name`, + `_q_0`.`fruitstruct`.`$id` AS `$id`, + `_q_0`.`fruitstruct`.`value` AS `value` +FROM `_q_0` AS `_q_0`; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 584e9d6..a9ae192 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -444,6 +444,9 @@ CAST('1998-09-02 00:00:00' AS DATETIME); CAST(x AS DATETIME) + interval '1' week; CAST(x AS DATETIME) + INTERVAL '1' week; +TS_OR_DS_TO_DATE('1998-12-01 00:00:01') - interval '90' day; +CAST('1998-09-02' AS DATE); + -------------------------------------- -- Comparisons -------------------------------------- @@ -681,6 +684,9 @@ CONCAT('a', x, y, 'bc'); 'a' || 'b' || x; CONCAT('ab', x); +CONCAT(a, b) IN (SELECT * FROM foo WHERE cond); +CONCAT(a, b) IN (SELECT * FROM foo WHERE cond); + -------------------------------------- -- DATE_TRUNC -------------------------------------- @@ -740,6 +746,9 @@ x >= CAST('2022-01-01' AS DATE); DATE_TRUNC('year', x) > CAST('2021-01-02' AS DATE); x >= CAST('2022-01-01' AS DATE); +DATE_TRUNC('year', x) > TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE('2021-01-02')); +x >= CAST('2022-01-01' AS DATE); + -- right is not a date DATE_TRUNC('year', x) <> '2021-01-02'; DATE_TRUNC('year', x) <> '2021-01-02'; @@ -758,6 +767,17 @@ x < CAST('2022-01-01' AS DATE) AND x >= CAST('2021-01-01' AS DATE); TIMESTAMP_TRUNC(x, YEAR) = CAST('2021-01-01' AS DATETIME); x < CAST('2022-01-01 00:00:00' AS DATETIME) AND x >= CAST('2021-01-01 00:00:00' AS DATETIME); +-- right side is not a date literal +DATE_TRUNC('day', x) = CAST(y AS DATE); +DATE_TRUNC('day', x) = CAST(y AS DATE); + +-- nested cast +DATE_TRUNC('day', x) = CAST(CAST('2021-01-01 01:02:03' AS DATETIME) AS DATE); +x < CAST('2021-01-02' AS DATE) AND x >= CAST('2021-01-01' AS DATE); + +TIMESTAMP_TRUNC(x, YEAR) = CAST(CAST('2021-01-01 01:02:03' AS DATE) AS DATETIME); +x < CAST('2022-01-01 00:00:00' AS DATETIME) AND x >= CAST('2021-01-01 00:00:00' AS DATETIME); + -------------------------------------- -- EQUALITY -------------------------------------- @@ -794,6 +814,9 @@ x = 2; x - INTERVAL 1 DAY = CAST('2021-01-01' AS DATE); x = CAST('2021-01-02' AS DATE); +x - INTERVAL 1 DAY = TS_OR_DS_TO_DATE('2021-01-01 00:00:01'); +x = CAST('2021-01-02' AS DATE); + x - INTERVAL 1 HOUR > CAST('2021-01-01' AS DATETIME); x > CAST('2021-01-01 01:00:00' AS DATETIME); diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index f50cf0b..2218182 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -4793,10 +4793,10 @@ WITH "foo" AS ( "foo"."i_item_sk" AS "i_item_sk", "foo"."d_moy" AS "d_moy", "foo"."mean" AS "mean", - CASE "foo"."mean" WHEN 0 THEN NULL ELSE "foo"."stdev" / "foo"."mean" END AS "cov" + CASE "foo"."mean" WHEN FALSE THEN NULL ELSE "foo"."stdev" / "foo"."mean" END AS "cov" FROM "foo" AS "foo" WHERE - CASE "foo"."mean" WHEN 0 THEN 0 ELSE "foo"."stdev" / "foo"."mean" END > 1 + CASE "foo"."mean" WHEN FALSE THEN 0 ELSE "foo"."stdev" / "foo"."mean" END > 1 ) SELECT "inv1"."w_warehouse_sk" AS "w_warehouse_sk", @@ -9775,7 +9775,7 @@ JOIN "date_dim" AS "d1" ON "catalog_sales"."cs_sold_date_sk" = "d1"."d_date_sk" AND "d1"."d_week_seq" = "d2"."d_week_seq" AND "d1"."d_year" = 2002 - AND "d3"."d_date" > CONCAT("d1"."d_date", INTERVAL '5' day) + AND "d3"."d_date" > "d1"."d_date" + INTERVAL '5' day GROUP BY "item"."i_item_desc", "warehouse"."w_warehouse_name", diff --git a/tests/test_executor.py b/tests/test_executor.py index ffe0229..c6b85c9 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -624,6 +624,8 @@ class TestExecutor(unittest.TestCase): ("LEFT('12345', 3)", "123"), ("RIGHT('12345', 3)", "345"), ("DATEDIFF('2022-01-03'::date, '2022-01-01'::TIMESTAMP::DATE)", 2), + ("TRIM(' foo ')", "foo"), + ("TRIM('afoob', 'ab')", "foo"), ]: with self.subTest(sql): result = execute(f"SELECT {sql}") diff --git a/tests/test_expressions.py b/tests/test_expressions.py index b1b5360..f8c8bcc 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -182,16 +182,21 @@ class TestExpressions(unittest.TestCase): self.assertEqual(parse_one("a.b.c").name, "c") def test_table_name(self): + bq_dashed_table = exp.to_table("a-1.b.c", dialect="bigquery") + self.assertEqual(exp.table_name(bq_dashed_table), '"a-1".b.c') + self.assertEqual(exp.table_name(bq_dashed_table, dialect="bigquery"), "`a-1`.b.c") + self.assertEqual(exp.table_name("a-1.b.c", dialect="bigquery"), "`a-1`.b.c") self.assertEqual(exp.table_name(parse_one("a", into=exp.Table)), "a") self.assertEqual(exp.table_name(parse_one("a.b", into=exp.Table)), "a.b") self.assertEqual(exp.table_name(parse_one("a.b.c", into=exp.Table)), "a.b.c") self.assertEqual(exp.table_name("a.b.c"), "a.b.c") + self.assertEqual(exp.table_name(exp.to_table("a.b.c.d.e", dialect="bigquery")), "a.b.c.d.e") + self.assertEqual(exp.table_name(exp.to_table("'@foo'", dialect="snowflake")), "'@foo'") + self.assertEqual(exp.table_name(exp.to_table("@foo", dialect="snowflake")), "@foo") self.assertEqual( exp.table_name(parse_one("foo.`{bar,er}`", read="databricks"), dialect="databricks"), "foo.`{bar,er}`", ) - self.assertEqual(exp.table_name(exp.to_table("a-1.b.c", dialect="bigquery")), '"a-1".b.c') - self.assertEqual(exp.table_name(exp.to_table("a.b.c.d.e", dialect="bigquery")), "a.b.c.d.e") def test_table(self): self.assertEqual(exp.table_("a", alias="b"), parse_one("select * from a b").find(exp.Table)) @@ -946,3 +951,8 @@ FROM foo""", with self.assertRaises(ParseError): exp.DataType.build("foo") + + def test_set_meta(self): + query = parse_one("SELECT * FROM foo /* sqlglot.meta x = 1, y = a, z */") + self.assertEqual(query.find(exp.Table).meta, {"x": "1", "y": "a", "z": True}) + self.assertEqual(query.sql(), "SELECT * FROM foo /* sqlglot.meta x = 1, y = a, z */") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 8775852..8fc3273 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -546,6 +546,53 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT) self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT) + def test_interval_math_annotation(self): + schema = { + "x": { + "a": "DATE", + "b": "DATETIME", + } + } + for sql, expected_type, *expected_sql in [ + ( + "SELECT '2023-01-01' + INTERVAL '1' DAY", + exp.DataType.Type.DATE, + "SELECT CAST('2023-01-01' AS DATE) + INTERVAL '1' DAY", + ), + ( + "SELECT '2023-01-01' + INTERVAL '1' HOUR", + exp.DataType.Type.DATETIME, + "SELECT CAST('2023-01-01' AS DATETIME) + INTERVAL '1' HOUR", + ), + ( + "SELECT '2023-01-01 00:00:01' + INTERVAL '1' HOUR", + exp.DataType.Type.DATETIME, + "SELECT CAST('2023-01-01 00:00:01' AS DATETIME) + INTERVAL '1' HOUR", + ), + ("SELECT 'nonsense' + INTERVAL '1' DAY", exp.DataType.Type.UNKNOWN), + ("SELECT x.a + INTERVAL '1' DAY FROM x AS x", exp.DataType.Type.DATE), + ("SELECT x.a + INTERVAL '1' HOUR FROM x AS x", exp.DataType.Type.DATETIME), + ("SELECT x.b + INTERVAL '1' DAY FROM x AS x", exp.DataType.Type.DATETIME), + ("SELECT x.b + INTERVAL '1' HOUR FROM x AS x", exp.DataType.Type.DATETIME), + ( + "SELECT DATE_ADD('2023-01-01', 1, 'DAY')", + exp.DataType.Type.DATE, + "SELECT DATE_ADD(CAST('2023-01-01' AS DATE), 1, 'DAY')", + ), + ( + "SELECT DATE_ADD('2023-01-01 00:00:00', 1, 'DAY')", + exp.DataType.Type.DATETIME, + "SELECT DATE_ADD(CAST('2023-01-01 00:00:00' AS DATETIME), 1, 'DAY')", + ), + ("SELECT DATE_ADD(x.a, 1, 'DAY') FROM x AS x", exp.DataType.Type.DATE), + ("SELECT DATE_ADD(x.a, 1, 'HOUR') FROM x AS x", exp.DataType.Type.DATETIME), + ("SELECT DATE_ADD(x.b, 1, 'DAY') FROM x AS x", exp.DataType.Type.DATETIME), + ]: + with self.subTest(sql): + expression = annotate_types(parse_one(sql), schema=schema) + self.assertEqual(expected_type, expression.expressions[0].type.this) + self.assertEqual(expected_sql[0] if expected_sql else sql, expression.sql()) + def test_lateral_annotation(self): expression = optimizer.optimize( parse_one("SELECT c FROM (select 1 a) as x LATERAL VIEW EXPLODE (a) AS c") diff --git a/tests/test_parser.py b/tests/test_parser.py index 74463fd..53e1a85 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -690,6 +690,31 @@ class TestParser(unittest.TestCase): LEFT JOIN b ON a.id = b.id """ ) + + self.assertIsNotNone(query) + + query = parse_one( + """ + SELECT * + FROM a + LEFT JOIN UNNEST(ARRAY[]) + LEFT JOIN UNNEST(ARRAY[]) + LEFT JOIN UNNEST(ARRAY[]) + LEFT JOIN UNNEST(ARRAY[]) + LEFT JOIN UNNEST(ARRAY[]) + LEFT JOIN UNNEST(ARRAY[]) + LEFT JOIN UNNEST(ARRAY[]) + LEFT JOIN UNNEST(ARRAY[]) + LEFT JOIN UNNEST(ARRAY[]) + LEFT JOIN UNNEST(ARRAY[]) + LEFT JOIN UNNEST(ARRAY[]) + LEFT JOIN UNNEST(ARRAY[]) + LEFT JOIN UNNEST(ARRAY[]) + LEFT JOIN UNNEST(ARRAY[]) + LEFT JOIN UNNEST(ARRAY[]) + """ + ) + self.assertIsNotNone(query) self.assertLessEqual(time.time() - now, 0.2) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index a5b1977..d588f07 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -156,9 +156,7 @@ SELECT * FROM foo -- comment 2 -- comment 3 SELECT * FROM foo""", - """/* comment 1 */ -/* comment 2 */ -/* comment 3 */ + """/* comment 1 */ /* comment 2 */ /* comment 3 */ SELECT * FROM foo""", @@ -182,8 +180,7 @@ line3*/ /*another comment*/ where 1=1 -- comment at the end""", * FROM tbl /* line1 line2 -line3 */ -/* another comment */ +line3 */ /* another comment */ WHERE 1 = 1 /* comment at the end */""", pretty=True, @@ -310,9 +307,7 @@ FROM v""", -- comment3 DROP TABLE IF EXISTS db.tba """, - """/* comment1 */ -/* comment2 */ -/* comment3 */ + """/* comment1 */ /* comment2 */ /* comment3 */ DROP TABLE IF EXISTS db.tba""", pretty=True, ) @@ -337,9 +332,7 @@ SELECT c FROM tb_01 WHERE - a /* comment5 */ = 1 AND b = 2 /* comment6 */ - /* and c = 1 */ - /* comment7 */""", + a /* comment5 */ = 1 AND b = 2 /* comment6 */ /* and c = 1 */ /* comment7 */""", pretty=True, ) self.validate( @@ -375,11 +368,17 @@ INNER JOIN b""", """SELECT * FROM a -/* comment 1 */ -/* comment 2 */ +/* comment 1 */ /* comment 2 */ LEFT OUTER JOIN b""", pretty=True, ) + self.validate( + "SELECT\n a /* sqlglot.meta case_sensitive */ -- noqa\nFROM tbl", + """SELECT + a /* sqlglot.meta case_sensitive */ /* noqa */ +FROM tbl""", + pretty=True, + ) def test_types(self): self.validate("INT 1", "CAST(1 AS INT)") @@ -468,12 +467,12 @@ LEFT OUTER JOIN b""", "ALTER TABLE integers ADD COLUMN k INT", ) self.validate( - "ALTER TABLE integers ALTER i SET DATA TYPE VARCHAR", - "ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR", + "ALTER TABLE integers ALTER i TYPE VARCHAR", + "ALTER TABLE integers ALTER COLUMN i SET DATA TYPE VARCHAR", ) self.validate( "ALTER TABLE integers ALTER i TYPE VARCHAR COLLATE foo USING bar", - "ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR COLLATE foo USING bar", + "ALTER TABLE integers ALTER COLUMN i SET DATA TYPE VARCHAR COLLATE foo USING bar", ) def test_time(self): @@ -604,7 +603,7 @@ LEFT OUTER JOIN b""", self.validate( "CREATE TEMPORARY TABLE test AS SELECT 1", "CREATE TEMPORARY VIEW test AS SELECT 1", - write="spark", + write="spark2", ) @mock.patch("sqlglot.helper.logger") |