diff options
Diffstat (limited to 'tests/dialects')
-rw-r--r-- | tests/dialects/test_bigquery.py | 44 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 20 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 11 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_oracle.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 450 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 6 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 46 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 20 | ||||
-rw-r--r-- | tests/dialects/test_teradata.py | 12 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 30 |
13 files changed, 399 insertions, 255 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 48480f9..2c8ac7b 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -9,6 +9,7 @@ from sqlglot import ( transpile, ) from sqlglot.helper import logger as helper_logger +from sqlglot.parser import logger as parser_logger from tests.dialects.test_dialect import Validator @@ -17,6 +18,29 @@ class TestBigQuery(Validator): maxDiff = None def test_bigquery(self): + self.validate_identity( + "create or replace view test (tenant_id OPTIONS(description='Test description on table creation')) select 1 as tenant_id, 1 as customer_id;", + "CREATE OR REPLACE VIEW test (tenant_id OPTIONS (description='Test description on table creation')) AS SELECT 1 AS tenant_id, 1 AS customer_id", + ) + + with self.assertLogs(helper_logger) as cm: + statements = parse( + """ + BEGIN + DECLARE 1; + IF from_date IS NULL THEN SET x = 1; + END IF; + END + """, + read="bigquery", + ) + self.assertIn("unsupported syntax", cm.output[0]) + + for actual, expected in zip( + statements, ("BEGIN DECLARE 1", "IF from_date IS NULL THEN SET x = 1", "END IF", "END") + ): + self.assertEqual(actual.sql(dialect="bigquery"), expected) + with self.assertLogs(helper_logger) as cm: self.validate_identity( "SELECT * FROM t AS t(c1, c2)", @@ -77,14 +101,16 @@ class TestBigQuery(Validator): with self.assertRaises(ParseError): transpile("DATE_ADD(x, day)", read="bigquery") - for_in_stmts = parse( - "FOR record IN (SELECT word FROM shakespeare) DO SELECT record.word; END FOR;", - read="bigquery", - ) - self.assertEqual( - [s.sql(dialect="bigquery") for s in for_in_stmts], - ["FOR record IN (SELECT word FROM shakespeare) DO SELECT record.word", "END FOR"], - ) + with self.assertLogs(parser_logger) as cm: + for_in_stmts = parse( + "FOR record IN (SELECT word FROM shakespeare) DO SELECT record.word; END FOR;", + read="bigquery", + ) + self.assertEqual( + [s.sql(dialect="bigquery") for s in for_in_stmts], + ["FOR record IN (SELECT word FROM shakespeare) DO SELECT record.word", "END FOR"], + ) + assert "'END FOR'" in cm.output[0] self.validate_identity("SELECT * FROM dataset.my_table TABLESAMPLE SYSTEM (10 PERCENT)") self.validate_identity("TIME('2008-12-25 15:30:00+08')") @@ -135,7 +161,7 @@ class TestBigQuery(Validator): self.validate_identity("""CREATE TABLE x (a STRUCT<b STRING OPTIONS (description='b')>)""") self.validate_identity("CAST(x AS TIMESTAMP)") self.validate_identity("REGEXP_EXTRACT(`foo`, 'bar: (.+?)', 1, 1)") - self.validate_identity("BEGIN A B C D E F") + self.validate_identity("BEGIN DECLARE y INT64", check_command_warning=True) self.validate_identity("BEGIN TRANSACTION") self.validate_identity("COMMIT TRANSACTION") self.validate_identity("ROLLBACK TRANSACTION") diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 84903aa..f36af41 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -42,7 +42,6 @@ class TestClickhouse(Validator): self.validate_identity("SELECT isNaN(1.0)") self.validate_identity("SELECT startsWith('Spider-Man', 'Spi')") self.validate_identity("SELECT xor(TRUE, FALSE)") - self.validate_identity("ATTACH DATABASE DEFAULT ENGINE = ORDINARY") self.validate_identity("CAST(['hello'], 'Array(Enum8(''hello'' = 1))')") self.validate_identity("SELECT x, COUNT() FROM y GROUP BY x WITH TOTALS") self.validate_identity("SELECT INTERVAL t.days DAY") @@ -76,6 +75,9 @@ class TestClickhouse(Validator): self.validate_identity("CAST(x as MEDIUMINT)", "CAST(x AS Int32)") self.validate_identity("SELECT arrayJoin([1, 2, 3] AS src) AS dst, 'Hello', src") self.validate_identity( + "ATTACH DATABASE DEFAULT ENGINE = ORDINARY", check_command_warning=True + ) + self.validate_identity( "SELECT n, source FROM (SELECT toFloat32(number % 10) AS n, 'original' AS source FROM numbers(10) WHERE number % 3 = 1) ORDER BY n WITH FILL" ) self.validate_identity( @@ -728,3 +730,19 @@ LIFETIME(MIN 0 MAX 0)""", ) self.validate_identity("""CREATE TABLE ip_data (ip4 IPv4, ip6 IPv6) ENGINE=TinyLog()""") self.validate_identity("""CREATE TABLE dates (dt1 Date32) ENGINE=TinyLog()""") + self.validate_all( + """ + CREATE TABLE t ( + a AggregateFunction(quantiles(0.5, 0.9), UInt64), + b AggregateFunction(quantiles, UInt64), + c SimpleAggregateFunction(sum, Float64) + )""", + write={ + "clickhouse": """CREATE TABLE t ( + a AggregateFunction(quantiles(0.5, 0.9), UInt64), + b AggregateFunction(quantiles, UInt64), + c SimpleAggregateFunction(sum, Float64) +)""" + }, + pretty=True, + ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 3cf4ddc..22e7d49 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -11,6 +11,7 @@ from sqlglot import ( parse_one, ) from sqlglot.dialects import BigQuery, Hive, Snowflake +from sqlglot.parser import logger as parser_logger class Validator(unittest.TestCase): @@ -19,8 +20,14 @@ class Validator(unittest.TestCase): def parse_one(self, sql): return parse_one(sql, read=self.dialect) - def validate_identity(self, sql, write_sql=None, pretty=False): - expression = self.parse_one(sql) + def validate_identity(self, sql, write_sql=None, pretty=False, check_command_warning=False): + if check_command_warning: + with self.assertLogs(parser_logger) as cm: + expression = self.parse_one(sql) + assert f"'{sql[:100]}' contains unsupported syntax" in cm.output[0] + else: + expression = self.parse_one(sql) + self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect, pretty=pretty)) return expression diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index e5f7e0c..f3b41b4 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -196,11 +196,13 @@ class TestDuckDB(Validator): self.validate_identity("SELECT ROW(x, x + 1, y) FROM (SELECT 1 AS x, 'a' AS y)") self.validate_identity("SELECT (x, x + 1, y) FROM (SELECT 1 AS x, 'a' AS y)") self.validate_identity("SELECT a.x FROM (SELECT {'x': 1, 'y': 2, 'z': 3} AS a)") - self.validate_identity("ATTACH DATABASE ':memory:' AS new_database") self.validate_identity("FROM x SELECT x UNION SELECT 1", "SELECT x FROM x UNION SELECT 1") self.validate_identity("FROM (FROM tbl)", "SELECT * FROM (SELECT * FROM tbl)") self.validate_identity("FROM tbl", "SELECT * FROM tbl") self.validate_identity( + "ATTACH DATABASE ':memory:' AS new_database", check_command_warning=True + ) + self.validate_identity( "SELECT {'yes': 'duck', 'maybe': 'goose', 'huh': NULL, 'no': 'heron'}" ) self.validate_identity( diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 8b5a945..d1b7589 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -399,6 +399,7 @@ class TestHive(Validator): ) def test_hive(self): + self.validate_identity("SET hiveconf:some_var = 5", check_command_warning=True) self.validate_identity("(VALUES (1 AS a, 2 AS b, 3))") self.validate_identity("SELECT * FROM my_table TIMESTAMP AS OF DATE_ADD(CURRENT_DATE, -1)") self.validate_identity("SELECT * FROM my_table VERSION AS OF DATE_ADD(CURRENT_DATE, -1)") @@ -441,13 +442,6 @@ class TestHive(Validator): ) self.validate_all( - "SET hiveconf:some_var = 5", - write={ - "hive": "SET hiveconf:some_var = 5", - "spark": "SET hiveconf:some_var = 5", - }, - ) - self.validate_all( "SELECT ${hiveconf:some_var}", write={ "hive": "SELECT ${hiveconf:some_var}", diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 85bf261..3a3e49e 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -126,7 +126,7 @@ class TestMySQL(Validator): self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')") self.validate_identity("SELECT @var1 := 1, @var2") self.validate_identity("UNLOCK TABLES") - self.validate_identity("LOCK TABLES `app_fields` WRITE") + self.validate_identity("LOCK TABLES `app_fields` WRITE", check_command_warning=True) self.validate_identity("SELECT 1 XOR 0") self.validate_identity("SELECT 1 && 0", "SELECT 1 AND 0") self.validate_identity("SELECT /*+ BKA(t1) NO_BKA(t2) */ * FROM t1 INNER JOIN t2") diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index fce714e..bc8f8bb 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -12,6 +12,7 @@ class TestOracle(Validator): exp.AlterTable ) + self.validate_identity("TIMESTAMP(3) WITH TIME ZONE") self.validate_identity("CURRENT_TIMESTAMP(precision)") self.validate_identity("ALTER TABLE tbl_name DROP FOREIGN KEY fk_symbol") self.validate_identity("ALTER TABLE Payments ADD Stock NUMBER NOT NULL") diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index f46eeba..dc00c85 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -7,214 +7,10 @@ class TestPostgres(Validator): maxDiff = None dialect = "postgres" - def test_ddl(self): - expr = parse_one("CREATE TABLE t (x INTERVAL day)", read="postgres") - cdef = expr.find(exp.ColumnDef) - cdef.args["kind"].assert_is(exp.DataType) - self.assertEqual(expr.sql(dialect="postgres"), "CREATE TABLE t (x INTERVAL DAY)") - - self.validate_identity("CREATE INDEX idx_x ON x USING BTREE(x, y) WHERE (NOT y IS NULL)") - self.validate_identity("CREATE TABLE test (elems JSONB[])") - self.validate_identity("CREATE TABLE public.y (x TSTZRANGE NOT NULL)") - self.validate_identity("CREATE TABLE test (foo HSTORE)") - self.validate_identity("CREATE TABLE test (foo JSONB)") - self.validate_identity("CREATE TABLE test (foo VARCHAR(64)[])") - self.validate_identity("CREATE TABLE test (foo INT) PARTITION BY HASH(foo)") - self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING a") - self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING a, b") - self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING *") - self.validate_identity("UPDATE tbl_name SET foo = 123 RETURNING a") - self.validate_identity("CREATE TABLE cities_partdef PARTITION OF cities DEFAULT") - self.validate_identity( - "CREATE CONSTRAINT TRIGGER my_trigger AFTER INSERT OR DELETE OR UPDATE OF col_a, col_b ON public.my_table DEFERRABLE INITIALLY DEFERRED FOR EACH ROW EXECUTE FUNCTION do_sth()" - ) - self.validate_identity( - "CREATE TABLE cust_part3 PARTITION OF customers FOR VALUES WITH (MODULUS 3, REMAINDER 2)" - ) - self.validate_identity( - "CREATE TABLE measurement_y2016m07 PARTITION OF measurement (unitsales DEFAULT 0) FOR VALUES FROM ('2016-07-01') TO ('2016-08-01')" - ) - self.validate_identity( - "CREATE TABLE measurement_ym_older PARTITION OF measurement_year_month FOR VALUES FROM (MINVALUE, MINVALUE) TO (2016, 11)" - ) - self.validate_identity( - "CREATE TABLE measurement_ym_y2016m11 PARTITION OF measurement_year_month FOR VALUES FROM (2016, 11) TO (2016, 12)" - ) - self.validate_identity( - "CREATE TABLE cities_ab PARTITION OF cities (CONSTRAINT city_id_nonzero CHECK (city_id <> 0)) FOR VALUES IN ('a', 'b')" - ) - self.validate_identity( - "CREATE TABLE cities_ab PARTITION OF cities (CONSTRAINT city_id_nonzero CHECK (city_id <> 0)) FOR VALUES IN ('a', 'b') PARTITION BY RANGE(population)" - ) - self.validate_identity( - "CREATE INDEX foo ON bar.baz USING btree(col1 varchar_pattern_ops ASC, col2)" - ) - self.validate_identity( - "CREATE INDEX index_issues_on_title_trigram ON public.issues USING gin(title public.gin_trgm_ops)" - ) - self.validate_identity( - "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO NOTHING RETURNING *" - ) - self.validate_identity( - "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO UPDATE SET x.id = 1 RETURNING *" - ) - self.validate_identity( - "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO UPDATE SET x.id = excluded.id RETURNING *" - ) - self.validate_identity( - "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT ON CONSTRAINT pkey DO NOTHING RETURNING *" - ) - self.validate_identity( - "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT ON CONSTRAINT pkey DO UPDATE SET x.id = 1 RETURNING *" - ) - self.validate_identity( - "DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid RETURNING a" - ) - self.validate_identity( - "CREATE TABLE test (x TIMESTAMP WITHOUT TIME ZONE[][])", - "CREATE TABLE test (x TIMESTAMP[][])", - ) - self.validate_identity( - "CREATE UNLOGGED TABLE foo AS WITH t(c) AS (SELECT 1) SELECT * FROM (SELECT c AS c FROM t) AS temp" - ) - self.validate_identity( - "WITH t(c) AS (SELECT 1) SELECT * INTO UNLOGGED foo FROM (SELECT c AS c FROM t) AS temp" - ) - - self.validate_all( - "CREATE OR REPLACE FUNCTION function_name (input_a character varying DEFAULT NULL::character varying)", - write={ - "postgres": "CREATE OR REPLACE FUNCTION function_name(input_a VARCHAR DEFAULT CAST(NULL AS VARCHAR))", - }, - ) - self.validate_all( - "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)", - write={ - "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)" - }, - ) - self.validate_all( - "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)", - write={ - "postgres": "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)" - }, - ) - self.validate_all( - "CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))", - write={ - "postgres": "CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))" - }, - ) - self.validate_all( - "CREATE TABLE products (" - "product_no INT UNIQUE," - " name TEXT," - " price DECIMAL CHECK (price > 0)," - " discounted_price DECIMAL CONSTRAINT positive_discount CHECK (discounted_price > 0)," - " CHECK (product_no > 1)," - " CONSTRAINT valid_discount CHECK (price > discounted_price))", - write={ - "postgres": "CREATE TABLE products (" - "product_no INT UNIQUE," - " name TEXT," - " price DECIMAL CHECK (price > 0)," - " discounted_price DECIMAL CONSTRAINT positive_discount CHECK (discounted_price > 0)," - " CHECK (product_no > 1)," - " CONSTRAINT valid_discount CHECK (price > discounted_price))" - }, - ) - self.validate_identity( - """ - CREATE INDEX index_ci_builds_on_commit_id_and_artifacts_expireatandidpartial - ON public.ci_builds - USING btree (commit_id, artifacts_expire_at, id) - WHERE ( - ((type)::text = 'Ci::Build'::text) - AND ((retried = false) OR (retried IS NULL)) - AND ((name)::text = ANY (ARRAY[ - ('sast'::character varying)::text, - ('dependency_scanning'::character varying)::text, - ('sast:container'::character varying)::text, - ('container_scanning'::character varying)::text, - ('dast'::character varying)::text - ])) - ) - """, - "CREATE INDEX index_ci_builds_on_commit_id_and_artifacts_expireatandidpartial ON public.ci_builds USING btree(commit_id, artifacts_expire_at, id) WHERE ((CAST((type) AS TEXT) = CAST('Ci::Build' AS TEXT)) AND ((retried = FALSE) OR (retried IS NULL)) AND (CAST((name) AS TEXT) = ANY (ARRAY[CAST((CAST('sast' AS VARCHAR)) AS TEXT), CAST((CAST('dependency_scanning' AS VARCHAR)) AS TEXT), CAST((CAST('sast:container' AS VARCHAR)) AS TEXT), CAST((CAST('container_scanning' AS VARCHAR)) AS TEXT), CAST((CAST('dast' AS VARCHAR)) AS TEXT)])))", - ) - self.validate_identity( - "CREATE INDEX index_ci_pipelines_on_project_idandrefandiddesc ON public.ci_pipelines USING btree(project_id, ref, id DESC)" - ) - - with self.assertRaises(ParseError): - transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres") - with self.assertRaises(ParseError): - transpile( - "CREATE TABLE products (price DECIMAL, CHECK price > 1)", - read="postgres", - ) - - def test_unnest(self): - self.validate_identity( - "SELECT * FROM UNNEST(ARRAY[1, 2], ARRAY['foo', 'bar', 'baz']) AS x(a, b)" - ) - - self.validate_all( - "SELECT UNNEST(c) FROM t", - write={ - "hive": "SELECT EXPLODE(c) FROM t", - "postgres": "SELECT UNNEST(c) FROM t", - "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col FROM t CROSS JOIN UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(c)))) AS _u(pos) CROSS JOIN UNNEST(c) WITH ORDINALITY AS _u_2(col, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(c) AND _u_2.pos_2 = CARDINALITY(c))", - }, - ) - self.validate_all( - "SELECT UNNEST(ARRAY[1])", - write={ - "hive": "SELECT EXPLODE(ARRAY(1))", - "postgres": "SELECT UNNEST(ARRAY[1])", - "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(ARRAY[1]) AND _u_2.pos_2 = CARDINALITY(ARRAY[1]))", - }, - ) - - def test_array_offset(self): - with self.assertLogs(helper_logger) as cm: - self.validate_all( - "SELECT col[1]", - write={ - "bigquery": "SELECT col[0]", - "duckdb": "SELECT col[1]", - "hive": "SELECT col[0]", - "postgres": "SELECT col[1]", - "presto": "SELECT col[1]", - }, - ) - - self.assertEqual( - cm.output, - [ - "WARNING:sqlglot:Applying array index offset (-1)", - "WARNING:sqlglot:Applying array index offset (1)", - "WARNING:sqlglot:Applying array index offset (1)", - "WARNING:sqlglot:Applying array index offset (1)", - ], - ) - - def test_operator(self): - expr = parse_one("1 OPERATOR(+) 2 OPERATOR(*) 3", read="postgres") - - expr.left.assert_is(exp.Operator) - expr.left.left.assert_is(exp.Literal) - expr.left.right.assert_is(exp.Literal) - expr.right.assert_is(exp.Literal) - self.assertEqual(expr.sql(dialect="postgres"), "1 OPERATOR(+) 2 OPERATOR(*) 3") - - self.validate_identity("SELECT operator FROM t") - self.validate_identity("SELECT 1 OPERATOR(+) 2") - self.validate_identity("SELECT 1 OPERATOR(+) /* foo */ 2") - self.validate_identity("SELECT 1 OPERATOR(pg_catalog.+) 2") - def test_postgres(self): - self.validate_identity("EXEC AS myfunc @id = 123") + self.validate_identity("SELECT CURRENT_USER") + self.validate_identity("CAST(1 AS DECIMAL) / CAST(2 AS DECIMAL) * -100") + self.validate_identity("EXEC AS myfunc @id = 123", check_command_warning=True) expr = parse_one( "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)", read="postgres" @@ -782,6 +578,246 @@ class TestPostgres(Validator): self.assertIsInstance(parse_one("id::UUID", read="postgres"), exp.Cast) + def test_ddl(self): + expr = parse_one("CREATE TABLE t (x INTERVAL day)", read="postgres") + cdef = expr.find(exp.ColumnDef) + cdef.args["kind"].assert_is(exp.DataType) + self.assertEqual(expr.sql(dialect="postgres"), "CREATE TABLE t (x INTERVAL DAY)") + + self.validate_identity("CREATE INDEX et_vid_idx ON et(vid) INCLUDE (fid)") + self.validate_identity("CREATE INDEX idx_x ON x USING BTREE(x, y) WHERE (NOT y IS NULL)") + self.validate_identity("CREATE TABLE test (elems JSONB[])") + self.validate_identity("CREATE TABLE public.y (x TSTZRANGE NOT NULL)") + self.validate_identity("CREATE TABLE test (foo HSTORE)") + self.validate_identity("CREATE TABLE test (foo JSONB)") + self.validate_identity("CREATE TABLE test (foo VARCHAR(64)[])") + self.validate_identity("CREATE TABLE test (foo INT) PARTITION BY HASH(foo)") + self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING a") + self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING a, b") + self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING *") + self.validate_identity("UPDATE tbl_name SET foo = 123 RETURNING a") + self.validate_identity("CREATE TABLE cities_partdef PARTITION OF cities DEFAULT") + self.validate_identity("CREATE TABLE t (c CHAR(2) UNIQUE NOT NULL) INHERITS (t1)") + self.validate_identity("CREATE TABLE s.t (c CHAR(2) UNIQUE NOT NULL) INHERITS (s.t1, s.t2)") + self.validate_identity("CREATE FUNCTION x(INT) RETURNS INT SET search_path = 'public'") + self.validate_identity( + "CREATE TABLE cust_part3 PARTITION OF customers FOR VALUES WITH (MODULUS 3, REMAINDER 2)" + ) + self.validate_identity( + "CREATE TABLE measurement_y2016m07 PARTITION OF measurement (unitsales DEFAULT 0) FOR VALUES FROM ('2016-07-01') TO ('2016-08-01')" + ) + self.validate_identity( + "CREATE TABLE measurement_ym_older PARTITION OF measurement_year_month FOR VALUES FROM (MINVALUE, MINVALUE) TO (2016, 11)" + ) + self.validate_identity( + "CREATE TABLE measurement_ym_y2016m11 PARTITION OF measurement_year_month FOR VALUES FROM (2016, 11) TO (2016, 12)" + ) + self.validate_identity( + "CREATE TABLE cities_ab PARTITION OF cities (CONSTRAINT city_id_nonzero CHECK (city_id <> 0)) FOR VALUES IN ('a', 'b')" + ) + self.validate_identity( + "CREATE TABLE cities_ab PARTITION OF cities (CONSTRAINT city_id_nonzero CHECK (city_id <> 0)) FOR VALUES IN ('a', 'b') PARTITION BY RANGE(population)" + ) + self.validate_identity( + "CREATE INDEX foo ON bar.baz USING btree(col1 varchar_pattern_ops ASC, col2)" + ) + self.validate_identity( + "CREATE INDEX index_issues_on_title_trigram ON public.issues USING gin(title public.gin_trgm_ops)" + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO NOTHING RETURNING *" + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO UPDATE SET x.id = 1 RETURNING *" + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO UPDATE SET x.id = excluded.id RETURNING *" + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT ON CONSTRAINT pkey DO NOTHING RETURNING *" + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT ON CONSTRAINT pkey DO UPDATE SET x.id = 1 RETURNING *" + ) + self.validate_identity( + "DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid RETURNING a" + ) + self.validate_identity( + "WITH t(c) AS (SELECT 1) SELECT * INTO UNLOGGED foo FROM (SELECT c AS c FROM t) AS temp" + ) + self.validate_identity( + "CREATE TABLE test (x TIMESTAMP WITHOUT TIME ZONE[][])", + "CREATE TABLE test (x TIMESTAMP[][])", + ) + self.validate_identity( + "CREATE FUNCTION add(INT, INT) RETURNS INT SET search_path TO 'public' AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE", + check_command_warning=True, + ) + self.validate_identity( + "CREATE FUNCTION x(INT) RETURNS INT SET foo FROM CURRENT", + check_command_warning=True, + ) + self.validate_identity( + "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT", + check_command_warning=True, + ) + self.validate_identity( + "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE CALLED ON NULL INPUT", + check_command_warning=True, + ) + self.validate_identity( + "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE STRICT", + check_command_warning=True, + ) + self.validate_identity( + "CREATE CONSTRAINT TRIGGER my_trigger AFTER INSERT OR DELETE OR UPDATE OF col_a, col_b ON public.my_table DEFERRABLE INITIALLY DEFERRED FOR EACH ROW EXECUTE FUNCTION do_sth()", + check_command_warning=True, + ) + self.validate_identity( + "CREATE UNLOGGED TABLE foo AS WITH t(c) AS (SELECT 1) SELECT * FROM (SELECT c AS c FROM t) AS temp", + check_command_warning=True, + ) + self.validate_identity( + "CREATE FUNCTION x(INT) RETURNS INT SET search_path TO 'public'", + "CREATE FUNCTION x(INT) RETURNS INT SET search_path = 'public'", + ) + self.validate_identity( + "CREATE TABLE test (x TIMESTAMP WITHOUT TIME ZONE[][])", + "CREATE TABLE test (x TIMESTAMP[][])", + ) + + self.validate_all( + "CREATE OR REPLACE FUNCTION function_name (input_a character varying DEFAULT NULL::character varying)", + write={ + "postgres": "CREATE OR REPLACE FUNCTION function_name(input_a VARCHAR DEFAULT CAST(NULL AS VARCHAR))", + }, + ) + self.validate_all( + "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)", + write={ + "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)" + }, + ) + self.validate_all( + "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)", + write={ + "postgres": "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)" + }, + ) + self.validate_all( + "CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))", + write={ + "postgres": "CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))" + }, + ) + self.validate_all( + "CREATE TABLE products (" + "product_no INT UNIQUE," + " name TEXT," + " price DECIMAL CHECK (price > 0)," + " discounted_price DECIMAL CONSTRAINT positive_discount CHECK (discounted_price > 0)," + " CHECK (product_no > 1)," + " CONSTRAINT valid_discount CHECK (price > discounted_price))", + write={ + "postgres": "CREATE TABLE products (" + "product_no INT UNIQUE," + " name TEXT," + " price DECIMAL CHECK (price > 0)," + " discounted_price DECIMAL CONSTRAINT positive_discount CHECK (discounted_price > 0)," + " CHECK (product_no > 1)," + " CONSTRAINT valid_discount CHECK (price > discounted_price))" + }, + ) + self.validate_identity( + """ + CREATE INDEX index_ci_builds_on_commit_id_and_artifacts_expireatandidpartial + ON public.ci_builds + USING btree (commit_id, artifacts_expire_at, id) + WHERE ( + ((type)::text = 'Ci::Build'::text) + AND ((retried = false) OR (retried IS NULL)) + AND ((name)::text = ANY (ARRAY[ + ('sast'::character varying)::text, + ('dependency_scanning'::character varying)::text, + ('sast:container'::character varying)::text, + ('container_scanning'::character varying)::text, + ('dast'::character varying)::text + ])) + ) + """, + "CREATE INDEX index_ci_builds_on_commit_id_and_artifacts_expireatandidpartial ON public.ci_builds USING btree(commit_id, artifacts_expire_at, id) WHERE ((CAST((type) AS TEXT) = CAST('Ci::Build' AS TEXT)) AND ((retried = FALSE) OR (retried IS NULL)) AND (CAST((name) AS TEXT) = ANY (ARRAY[CAST((CAST('sast' AS VARCHAR)) AS TEXT), CAST((CAST('dependency_scanning' AS VARCHAR)) AS TEXT), CAST((CAST('sast:container' AS VARCHAR)) AS TEXT), CAST((CAST('container_scanning' AS VARCHAR)) AS TEXT), CAST((CAST('dast' AS VARCHAR)) AS TEXT)])))", + ) + self.validate_identity( + "CREATE INDEX index_ci_pipelines_on_project_idandrefandiddesc ON public.ci_pipelines USING btree(project_id, ref, id DESC)" + ) + + with self.assertRaises(ParseError): + transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres") + with self.assertRaises(ParseError): + transpile( + "CREATE TABLE products (price DECIMAL, CHECK price > 1)", + read="postgres", + ) + + def test_unnest(self): + self.validate_identity( + "SELECT * FROM UNNEST(ARRAY[1, 2], ARRAY['foo', 'bar', 'baz']) AS x(a, b)" + ) + + self.validate_all( + "SELECT UNNEST(c) FROM t", + write={ + "hive": "SELECT EXPLODE(c) FROM t", + "postgres": "SELECT UNNEST(c) FROM t", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col FROM t CROSS JOIN UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(c)))) AS _u(pos) CROSS JOIN UNNEST(c) WITH ORDINALITY AS _u_2(col, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(c) AND _u_2.pos_2 = CARDINALITY(c))", + }, + ) + self.validate_all( + "SELECT UNNEST(ARRAY[1])", + write={ + "hive": "SELECT EXPLODE(ARRAY(1))", + "postgres": "SELECT UNNEST(ARRAY[1])", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(ARRAY[1]) AND _u_2.pos_2 = CARDINALITY(ARRAY[1]))", + }, + ) + + def test_array_offset(self): + with self.assertLogs(helper_logger) as cm: + self.validate_all( + "SELECT col[1]", + write={ + "bigquery": "SELECT col[0]", + "duckdb": "SELECT col[1]", + "hive": "SELECT col[0]", + "postgres": "SELECT col[1]", + "presto": "SELECT col[1]", + }, + ) + + self.assertEqual( + cm.output, + [ + "WARNING:sqlglot:Applying array index offset (-1)", + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (1)", + ], + ) + + def test_operator(self): + expr = parse_one("1 OPERATOR(+) 2 OPERATOR(*) 3", read="postgres") + + expr.left.assert_is(exp.Operator) + expr.left.left.assert_is(exp.Literal) + expr.left.right.assert_is(exp.Literal) + expr.right.assert_is(exp.Literal) + self.assertEqual(expr.sql(dialect="postgres"), "1 OPERATOR(+) 2 OPERATOR(*) 3") + + self.validate_identity("SELECT operator FROM t") + self.validate_identity("SELECT 1 OPERATOR(+) 2") + self.validate_identity("SELECT 1 OPERATOR(+) /* foo */ 2") + self.validate_identity("SELECT 1 OPERATOR(pg_catalog.+) 2") + def test_bool_or(self): self.validate_all( "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 88fef67..9ccd955 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -332,10 +332,12 @@ class TestRedshift(Validator): "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO" ) self.validate_identity( - "COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'" + "COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'", + check_command_warning=True, ) self.validate_identity( - "UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'" + "UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'", + check_command_warning=True, ) self.validate_identity( "CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)" diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 0882290..7e41fd4 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -10,6 +10,7 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): + self.validate_identity("ALTER TABLE table1 CLUSTER BY (name DESC)") self.validate_identity( "INSERT OVERWRITE TABLE t SELECT 1", "INSERT OVERWRITE INTO t SELECT 1" ) @@ -39,8 +40,8 @@ WHERE )""", ) - self.validate_identity("RM @parquet_stage") - self.validate_identity("REMOVE @parquet_stage") + self.validate_identity("RM @parquet_stage", check_command_warning=True) + self.validate_identity("REMOVE @parquet_stage", check_command_warning=True) self.validate_identity("SELECT TIMESTAMP_FROM_PARTS(d, t)") self.validate_identity("SELECT GET_PATH(v, 'attr[0].name') FROM vartab") self.validate_identity("SELECT TO_ARRAY(CAST(x AS ARRAY))") @@ -84,6 +85,10 @@ WHERE "SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE (0.1)" ) self.validate_identity( + "SELECT p FROM t WHERE p:val NOT IN ('2')", + "SELECT p FROM t WHERE NOT GET_PATH(p, 'val') IN ('2')", + ) + self.validate_identity( """SELECT PARSE_JSON('{"x": "hello"}'):x LIKE 'hello'""", """SELECT GET_PATH(PARSE_JSON('{"x": "hello"}'), 'x') LIKE 'hello'""", ) @@ -777,9 +782,10 @@ WHERE self.validate_identity("SELECT * FROM @namespace.mystage/path/to/file.json.gz") self.validate_identity("SELECT * FROM @namespace.%table_name/path/to/file.json.gz") self.validate_identity("SELECT * FROM '@external/location' (FILE_FORMAT => 'path.to.csv')") - self.validate_identity("PUT file:///dir/tmp.csv @%table") + self.validate_identity("PUT file:///dir/tmp.csv @%table", check_command_warning=True) self.validate_identity( - 'COPY INTO NEW_TABLE ("foo", "bar") FROM (SELECT $1, $2, $3, $4 FROM @%old_table)' + 'COPY INTO NEW_TABLE ("foo", "bar") FROM (SELECT $1, $2, $3, $4 FROM @%old_table)', + check_command_warning=True, ) self.validate_identity( "SELECT * FROM @foo/bar (FILE_FORMAT => ds_sandbox.test.my_csv_format, PATTERN => 'test') AS bla" @@ -1095,7 +1101,7 @@ WHERE ) def test_stored_procedures(self): - self.validate_identity("CALL a.b.c(x, y)") + self.validate_identity("CALL a.b.c(x, y)", check_command_warning=True) self.validate_identity( "CREATE PROCEDURE a.b.c(x INT, y VARIANT) RETURNS OBJECT EXECUTE AS CALLER AS 'BEGIN SELECT 1; END;'" ) @@ -1449,10 +1455,10 @@ MATCH_RECOGNIZE ( def test_show(self): # Parsed as Command - self.validate_identity("SHOW TABLES LIKE 'line%' IN tpch.public") - - ast = parse_one("SHOW TABLES HISTORY IN tpch.public", read="snowflake") - self.assertIsInstance(ast, exp.Command) + self.validate_identity( + "SHOW TABLES LIKE 'line%' IN tpch.public", check_command_warning=True + ) + self.validate_identity("SHOW TABLES HISTORY IN tpch.public", check_command_warning=True) # Parsed as Show self.validate_identity("SHOW PRIMARY KEYS") @@ -1469,6 +1475,18 @@ MATCH_RECOGNIZE ( 'SHOW TERSE PRIMARY KEYS IN "TEST"."PUBLIC"."customers"', 'SHOW PRIMARY KEYS IN TABLE "TEST"."PUBLIC"."customers"', ) + self.validate_identity( + "show terse schemas in database db1 starts with 'a' limit 10 from 'b'", + "SHOW TERSE SCHEMAS IN DATABASE db1 STARTS WITH 'a' LIMIT 10 FROM 'b'", + ) + self.validate_identity( + "show terse objects in schema db1.schema1 starts with 'a' limit 10 from 'b'", + "SHOW TERSE OBJECTS IN SCHEMA db1.schema1 STARTS WITH 'a' LIMIT 10 FROM 'b'", + ) + self.validate_identity( + "show terse objects in db1.schema1 starts with 'a' limit 10 from 'b'", + "SHOW TERSE OBJECTS IN SCHEMA db1.schema1 STARTS WITH 'a' LIMIT 10 FROM 'b'", + ) ast = parse_one('SHOW PRIMARY KEYS IN "TEST"."PUBLIC"."customers"', read="snowflake") table = ast.find(exp.Table) @@ -1489,6 +1507,16 @@ MATCH_RECOGNIZE ( self.assertEqual(literal.sql(dialect="snowflake"), "'_testing%'") + ast = parse_one("SHOW SCHEMAS IN DATABASE db1", read="snowflake") + self.assertEqual(ast.args.get("scope_kind"), "DATABASE") + table = ast.find(exp.Table) + self.assertEqual(table.sql(dialect="snowflake"), "db1") + + ast = parse_one("SHOW OBJECTS IN db1.schema1", read="snowflake") + self.assertEqual(ast.args.get("scope_kind"), "SCHEMA") + table = ast.find(exp.Table) + self.assertEqual(table.sql(dialect="snowflake"), "db1.schema1") + def test_swap(self): ast = parse_one("ALTER TABLE a SWAP WITH b", read="snowflake") assert isinstance(ast, exp.AlterTable) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 56a573a..6044037 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -227,7 +227,6 @@ TBLPROPERTIES ( ) def test_spark(self): - self.validate_identity("FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), 'utc')") expr = parse_one("any_value(col, true)", read="spark") self.assertIsInstance(expr.args.get("ignore_nulls"), exp.Boolean) self.assertEqual(expr.sql(dialect="spark"), "ANY_VALUE(col, TRUE)") @@ -277,6 +276,25 @@ TBLPROPERTIES ( ) self.validate_all( + "SELECT TO_UTC_TIMESTAMP('2016-08-31', 'Asia/Seoul')", + write={ + "bigquery": "SELECT DATETIME(TIMESTAMP(CAST('2016-08-31' AS DATETIME), 'Asia/Seoul'), 'UTC')", + "duckdb": "SELECT CAST('2016-08-31' AS TIMESTAMP) AT TIME ZONE 'Asia/Seoul' AT TIME ZONE 'UTC'", + "postgres": "SELECT CAST('2016-08-31' AS TIMESTAMP) AT TIME ZONE 'Asia/Seoul' AT TIME ZONE 'UTC'", + "presto": "SELECT WITH_TIMEZONE(CAST('2016-08-31' AS TIMESTAMP), 'Asia/Seoul') AT TIME ZONE 'UTC'", + "redshift": "SELECT CAST('2016-08-31' AS TIMESTAMP) AT TIME ZONE 'Asia/Seoul' AT TIME ZONE 'UTC'", + "snowflake": "SELECT CONVERT_TIMEZONE('Asia/Seoul', 'UTC', CAST('2016-08-31' AS TIMESTAMPNTZ))", + "spark": "SELECT TO_UTC_TIMESTAMP(CAST('2016-08-31' AS TIMESTAMP), 'Asia/Seoul')", + }, + ) + self.validate_all( + "SELECT FROM_UTC_TIMESTAMP('2016-08-31', 'Asia/Seoul')", + write={ + "presto": "SELECT CAST('2016-08-31' AS TIMESTAMP) AT TIME ZONE 'Asia/Seoul'", + "spark": "SELECT FROM_UTC_TIMESTAMP(CAST('2016-08-31' AS TIMESTAMP), 'Asia/Seoul')", + }, + ) + self.validate_all( "foo.bar", read={ "": "STRUCT_EXTRACT(foo, bar)", diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 85d4ebf..f3894fd 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -42,11 +42,13 @@ class TestTeradata(Validator): ) def test_statistics(self): - self.validate_identity("COLLECT STATISTICS ON tbl INDEX(col)") - self.validate_identity("COLLECT STATS ON tbl COLUMNS(col)") - self.validate_identity("COLLECT STATS COLUMNS(col) ON tbl") - self.validate_identity("HELP STATISTICS personel.employee") - self.validate_identity("HELP STATISTICS personnel.employee FROM my_qcd") + self.validate_identity("COLLECT STATISTICS ON tbl INDEX(col)", check_command_warning=True) + self.validate_identity("COLLECT STATS ON tbl COLUMNS(col)", check_command_warning=True) + self.validate_identity("COLLECT STATS COLUMNS(col) ON tbl", check_command_warning=True) + self.validate_identity("HELP STATISTICS personel.employee", check_command_warning=True) + self.validate_identity( + "HELP STATISTICS personnel.employee FROM my_qcd", check_command_warning=True + ) def test_create(self): self.validate_identity( diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 7cf9971..101d356 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1,4 +1,5 @@ from sqlglot import exp, parse, parse_one +from sqlglot.parser import logger as parser_logger from tests.dialects.test_dialect import Validator @@ -7,7 +8,7 @@ class TestTSQL(Validator): def test_tsql(self): self.validate_identity("ROUND(x, 1, 0)") - self.validate_identity("EXEC MyProc @id=7, @name='Lochristi'") + self.validate_identity("EXEC MyProc @id=7, @name='Lochristi'", check_command_warning=True) # https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN # tsql allows .. which means use the default schema self.validate_identity("SELECT * FROM a..b") @@ -225,7 +226,7 @@ class TestTSQL(Validator): "MERGE INTO mytable WITH (HOLDLOCK) AS T USING mytable_merge AS S " "ON (T.user_id = S.user_id) WHEN NOT MATCHED THEN INSERT (c1, c2) VALUES (S.c1, S.c2)" ) - self.validate_identity("UPDATE STATISTICS x") + self.validate_identity("UPDATE STATISTICS x", check_command_warning=True) self.validate_identity("UPDATE x SET y = 1 OUTPUT x.a, x.b INTO @y FROM y") self.validate_identity("UPDATE x SET y = 1 OUTPUT x.a, x.b FROM y") self.validate_identity("INSERT INTO x (y) OUTPUT x.a, x.b INTO l SELECT * FROM z") @@ -238,14 +239,16 @@ class TestTSQL(Validator): self.validate_identity("END") self.validate_identity("@x") self.validate_identity("#x") - self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'") - self.validate_identity("PRINT @TestVariable") + self.validate_identity("PRINT @TestVariable", check_command_warning=True) self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar") self.validate_identity("INSERT INTO @TestTable VALUES (1, 'Value1', 12, 20)") self.validate_identity("SELECT * FROM #foo") self.validate_identity("SELECT * FROM ##foo") self.validate_identity("SELECT a = 1", "SELECT 1 AS a") self.validate_identity( + "DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'", check_command_warning=True + ) + self.validate_identity( "SELECT a = 1 UNION ALL SELECT a = b", "SELECT 1 AS a UNION ALL SELECT b AS a" ) self.validate_identity( @@ -789,7 +792,8 @@ class TestTSQL(Validator): def test_udf(self): self.validate_identity( - "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)" + "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)", + check_command_warning=True, ) self.validate_identity( "CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar" @@ -882,8 +886,9 @@ WHERE "END", ] - for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): - self.assertEqual(expr.sql(dialect="tsql"), expected_sql) + with self.assertLogs(parser_logger) as cm: + for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): + self.assertEqual(expr.sql(dialect="tsql"), expected_sql) sql = """ CREATE PROC [dbo].[transform_proc] AS @@ -902,8 +907,9 @@ WHERE "CREATE TABLE [target_schema].[target_table] (a INTEGER) WITH (DISTRIBUTION=REPLICATE, HEAP)", ] - for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): - self.assertEqual(expr.sql(dialect="tsql"), expected_sql) + with self.assertLogs(parser_logger) as cm: + for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): + self.assertEqual(expr.sql(dialect="tsql"), expected_sql) def test_charindex(self): self.validate_all( @@ -932,7 +938,11 @@ WHERE ) def test_len(self): - self.validate_all("LEN(x)", read={"": "LENGTH(x)"}, write={"spark": "LENGTH(x)"}) + self.validate_all( + "LEN(x)", read={"": "LENGTH(x)"}, write={"spark": "LENGTH(CAST(x AS STRING))"} + ) + self.validate_all("LEN(1)", write={"tsql": "LEN(1)", "spark": "LENGTH(CAST(1 AS STRING))"}) + self.validate_all("LEN('x')", write={"tsql": "LEN('x')", "spark": "LENGTH('x')"}) def test_replicate(self): self.validate_all("REPLICATE('x', 2)", write={"spark": "REPEAT('x', 2)"}) |