diff options
Diffstat (limited to 'tests/dialects')
-rw-r--r-- | tests/dialects/test_clickhouse.py | 7 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 66 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 9 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 18 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 142 |
5 files changed, 216 insertions, 26 deletions
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 2827dd4..905e1f4 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -17,7 +17,8 @@ class TestClickhouse(Validator): self.validate_identity("SELECT quantile(0.5)(a)") self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t") self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)") - self.validate_identity("position(a, b)") + self.validate_identity("position(haystack, needle)") + self.validate_identity("position(haystack, needle, position)") self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", @@ -48,6 +49,10 @@ class TestClickhouse(Validator): "clickhouse": "SELECT quantileIf(0.5)(a, TRUE)", }, ) + self.validate_all( + "SELECT position(needle IN haystack)", + write={"clickhouse": "SELECT position(haystack, needle)"}, + ) def test_cte(self): self.validate_identity("WITH 'x' AS foo SELECT foo") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index b2f4676..f1144ce 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -950,40 +950,40 @@ class TestDialect(Validator): }, ) self.validate_all( - "POSITION(' ' in x)", + "POSITION(needle in haystack)", write={ - "drill": "STRPOS(x, ' ')", - "duckdb": "STRPOS(x, ' ')", - "postgres": "STRPOS(x, ' ')", - "presto": "STRPOS(x, ' ')", - "spark": "LOCATE(' ', x)", - "clickhouse": "position(x, ' ')", - "snowflake": "POSITION(' ', x)", - "mysql": "LOCATE(' ', x)", + "drill": "STRPOS(haystack, needle)", + "duckdb": "STRPOS(haystack, needle)", + "postgres": "STRPOS(haystack, needle)", + "presto": "STRPOS(haystack, needle)", + "spark": "LOCATE(needle, haystack)", + "clickhouse": "position(haystack, needle)", + "snowflake": "POSITION(needle, haystack)", + "mysql": "LOCATE(needle, haystack)", }, ) self.validate_all( - "STR_POSITION(x, 'a')", + "STR_POSITION(haystack, needle)", write={ - "drill": "STRPOS(x, 'a')", - "duckdb": "STRPOS(x, 'a')", - "postgres": "STRPOS(x, 'a')", - "presto": "STRPOS(x, 'a')", - "spark": "LOCATE('a', x)", - "clickhouse": "position(x, 'a')", - "snowflake": "POSITION('a', x)", - "mysql": "LOCATE('a', x)", + "drill": "STRPOS(haystack, needle)", + "duckdb": "STRPOS(haystack, needle)", + "postgres": "STRPOS(haystack, needle)", + "presto": "STRPOS(haystack, needle)", + "spark": "LOCATE(needle, haystack)", + "clickhouse": "position(haystack, needle)", + "snowflake": "POSITION(needle, haystack)", + "mysql": "LOCATE(needle, haystack)", }, ) self.validate_all( - "POSITION('a', x, 3)", + "POSITION(needle, haystack, pos)", write={ - "drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", - "presto": "STRPOS(x, 'a', 3)", - "spark": "LOCATE('a', x, 3)", - "clickhouse": "position(x, 'a', 3)", - "snowflake": "POSITION('a', x, 3)", - "mysql": "LOCATE('a', x, 3)", + "drill": "STRPOS(SUBSTR(haystack, pos), needle) + pos - 1", + "presto": "STRPOS(haystack, needle, pos)", + "spark": "LOCATE(needle, haystack, pos)", + "clickhouse": "position(haystack, needle, pos)", + "snowflake": "POSITION(needle, haystack, pos)", + "mysql": "LOCATE(needle, haystack, pos)", }, ) self.validate_all( @@ -1365,3 +1365,19 @@ SELECT "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED THEN UPDATE * WHEN NOT MATCHED THEN INSERT *", }, ) + self.validate_all( + """ + MERGE a b USING c d ON b.id = d.id + WHEN MATCHED AND EXISTS ( + SELECT b.name + EXCEPT + SELECT d.name + ) + THEN UPDATE SET b.name = d.name + """, + write={ + "bigquery": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT DISTINCT SELECT d.name) THEN UPDATE SET b.name = d.name", + "snowflake": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT SELECT d.name) THEN UPDATE SET b.name = d.name", + "spark": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT SELECT d.name) THEN UPDATE SET b.name = d.name", + }, + ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index dfd2f8e..ce865e1 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -75,6 +75,15 @@ class TestMySQL(Validator): "spark": "CAST(x AS TEXT) + CAST(y AS TEXT)", }, ) + self.validate_all( + "CAST(x AS MEDIUMBLOB) + CAST(y AS LONGBLOB)", + read={ + "mysql": "CAST(x AS MEDIUMBLOB) + CAST(y AS LONGBLOB)", + }, + write={ + "spark": "CAST(x AS BLOB) + CAST(y AS BLOB)", + }, + ) def test_canonical_functions(self): self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)") diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 0e9ce9b..7bac166 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -13,6 +13,24 @@ class TestSnowflake(Validator): }, ) self.validate_all( + "SELECT * EXCLUDE a, b FROM xxx", + write={ + "snowflake": "SELECT * EXCLUDE (a, b) FROM xxx", + }, + ) + self.validate_all( + "SELECT * RENAME a AS b, c AS d FROM xxx", + write={ + "snowflake": "SELECT * RENAME (a AS b, c AS d) FROM xxx", + }, + ) + self.validate_all( + "SELECT * EXCLUDE a, b RENAME (c AS d, E as F) FROM xxx", + write={ + "snowflake": "SELECT * EXCLUDE (a, b) RENAME (c AS d, E AS F) FROM xxx", + }, + ) + self.validate_all( 'x:a:"b c"', write={ "duckdb": "x['a']['b c']", diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index d2972ca..4224a1e 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1,3 +1,4 @@ +from sqlglot import exp, parse, parse_one from tests.dialects.test_dialect import Validator @@ -5,6 +6,10 @@ class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): + self.validate_identity("SELECT CASE WHEN a > 1 THEN b END") + self.validate_identity("END") + self.validate_identity("@x") + self.validate_identity("#x") self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'") self.validate_identity("PRINT @TestVariable") self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar") @@ -87,6 +92,95 @@ class TestTSQL(Validator): }, ) + def test_udf(self): + self.validate_identity( + "CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar" + ) + self.validate_identity( + "CREATE PROC foo @ID INTEGER, @AGE INTEGER AS SELECT DB_NAME(@ID) AS ThatDB" + ) + self.validate_identity("CREATE PROC foo AS SELECT BAR() AS baz") + self.validate_identity("CREATE PROCEDURE foo AS SELECT BAR() AS baz") + self.validate_identity("CREATE FUNCTION foo(@bar INTEGER) RETURNS TABLE AS RETURN SELECT 1") + self.validate_identity("CREATE FUNCTION dbo.ISOweek(@DATE DATETIME2) RETURNS INTEGER") + + # The following two cases don't necessarily correspond to valid TSQL, but they are used to verify + # that the syntax RETURNS @return_variable TABLE <table_type_definition> ... is parsed correctly. + # + # See also "Transact-SQL Multi-Statement Table-Valued Function Syntax" + # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql?view=sql-server-ver16 + self.validate_identity( + "CREATE FUNCTION foo(@bar INTEGER) RETURNS @foo TABLE (x INTEGER, y NUMERIC) AS RETURN SELECT 1" + ) + self.validate_identity( + "CREATE FUNCTION foo() RETURNS @contacts TABLE (first_name VARCHAR(50), phone VARCHAR(25)) AS SELECT @fname, @phone" + ) + + self.validate_all( + """ + CREATE FUNCTION udfProductInYear ( + @model_year INT + ) + RETURNS TABLE + AS + RETURN + SELECT + product_name, + model_year, + list_price + FROM + production.products + WHERE + model_year = @model_year + """, + write={ + "tsql": """CREATE FUNCTION udfProductInYear( + @model_year INTEGER +) +RETURNS TABLE AS +RETURN SELECT + product_name, + model_year, + list_price +FROM production.products +WHERE + model_year = @model_year""", + }, + pretty=True, + ) + + sql = """ + CREATE procedure [TRANSF].[SP_Merge_Sales_Real] + @Loadid INTEGER + ,@NumberOfRows INTEGER + AS + BEGIN + SET XACT_ABORT ON; + + DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104); + DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104); + DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER); + DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER); + + DECLARE @SalesAmountBefore float; + SELECT @SalesAmountBefore=SUM(SalesAmount) FROM TRANSF.[Pre_Merge_Sales_Real] S; + END + """ + + expected_sqls = [ + 'CREATE PROCEDURE "TRANSF"."SP_Merge_Sales_Real" @Loadid INTEGER, @NumberOfRows INTEGER AS BEGIN SET XACT_ABORT ON', + "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)", + "DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104)", + "DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER)", + "DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER)", + "DECLARE @SalesAmountBefore float", + 'SELECT @SalesAmountBefore = SUM(SalesAmount) FROM TRANSF."Pre_Merge_Sales_Real" AS S', + "END", + ] + + for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): + self.assertEqual(expr.sql(dialect="tsql"), expected_sql) + def test_charindex(self): self.validate_all( "CHARINDEX(x, y, 9)", @@ -472,3 +566,51 @@ class TestTSQL(Validator): "EOMONTH(GETDATE(), -1)", write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"}, ) + + def test_variables(self): + # In TSQL @, # can be used as a prefix for variables/identifiers + expr = parse_one("@x", read="tsql") + self.assertIsInstance(expr, exp.Column) + self.assertIsInstance(expr.this, exp.Identifier) + + expr = parse_one("#x", read="tsql") + self.assertIsInstance(expr, exp.Column) + self.assertIsInstance(expr.this, exp.Identifier) + + def test_system_time(self): + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo'", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME AS OF 'foo'""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo' AS alias", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME AS OF 'foo' AS alias""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME FROM c TO d", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME FROM c TO d""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME BETWEEN c AND d", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME BETWEEN c AND d""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME CONTAINED IN (c, d)", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME CONTAINED IN (c, d)""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME ALL AS alias", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME ALL AS alias""", + }, + ) |