From be1cb18ea28222fca384a5459a024b7e9af5cadb Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 30 Jan 2023 18:08:37 +0100 Subject: Merging upstream version 10.5.10. Signed-off-by: Daniel Baumann --- tests/dialects/test_tsql.py | 142 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) (limited to 'tests/dialects/test_tsql.py') diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index d2972ca..4224a1e 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1,3 +1,4 @@ +from sqlglot import exp, parse, parse_one from tests.dialects.test_dialect import Validator @@ -5,6 +6,10 @@ class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): + self.validate_identity("SELECT CASE WHEN a > 1 THEN b END") + self.validate_identity("END") + self.validate_identity("@x") + self.validate_identity("#x") self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'") self.validate_identity("PRINT @TestVariable") self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar") @@ -87,6 +92,95 @@ class TestTSQL(Validator): }, ) + def test_udf(self): + self.validate_identity( + "CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar" + ) + self.validate_identity( + "CREATE PROC foo @ID INTEGER, @AGE INTEGER AS SELECT DB_NAME(@ID) AS ThatDB" + ) + self.validate_identity("CREATE PROC foo AS SELECT BAR() AS baz") + self.validate_identity("CREATE PROCEDURE foo AS SELECT BAR() AS baz") + self.validate_identity("CREATE FUNCTION foo(@bar INTEGER) RETURNS TABLE AS RETURN SELECT 1") + self.validate_identity("CREATE FUNCTION dbo.ISOweek(@DATE DATETIME2) RETURNS INTEGER") + + # The following two cases don't necessarily correspond to valid TSQL, but they are used to verify + # that the syntax RETURNS @return_variable TABLE ... is parsed correctly. + # + # See also "Transact-SQL Multi-Statement Table-Valued Function Syntax" + # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql?view=sql-server-ver16 + self.validate_identity( + "CREATE FUNCTION foo(@bar INTEGER) RETURNS @foo TABLE (x INTEGER, y NUMERIC) AS RETURN SELECT 1" + ) + self.validate_identity( + "CREATE FUNCTION foo() RETURNS @contacts TABLE (first_name VARCHAR(50), phone VARCHAR(25)) AS SELECT @fname, @phone" + ) + + self.validate_all( + """ + CREATE FUNCTION udfProductInYear ( + @model_year INT + ) + RETURNS TABLE + AS + RETURN + SELECT + product_name, + model_year, + list_price + FROM + production.products + WHERE + model_year = @model_year + """, + write={ + "tsql": """CREATE FUNCTION udfProductInYear( + @model_year INTEGER +) +RETURNS TABLE AS +RETURN SELECT + product_name, + model_year, + list_price +FROM production.products +WHERE + model_year = @model_year""", + }, + pretty=True, + ) + + sql = """ + CREATE procedure [TRANSF].[SP_Merge_Sales_Real] + @Loadid INTEGER + ,@NumberOfRows INTEGER + AS + BEGIN + SET XACT_ABORT ON; + + DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104); + DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104); + DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER); + DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER); + + DECLARE @SalesAmountBefore float; + SELECT @SalesAmountBefore=SUM(SalesAmount) FROM TRANSF.[Pre_Merge_Sales_Real] S; + END + """ + + expected_sqls = [ + 'CREATE PROCEDURE "TRANSF"."SP_Merge_Sales_Real" @Loadid INTEGER, @NumberOfRows INTEGER AS BEGIN SET XACT_ABORT ON', + "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)", + "DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104)", + "DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER)", + "DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER)", + "DECLARE @SalesAmountBefore float", + 'SELECT @SalesAmountBefore = SUM(SalesAmount) FROM TRANSF."Pre_Merge_Sales_Real" AS S', + "END", + ] + + for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): + self.assertEqual(expr.sql(dialect="tsql"), expected_sql) + def test_charindex(self): self.validate_all( "CHARINDEX(x, y, 9)", @@ -472,3 +566,51 @@ class TestTSQL(Validator): "EOMONTH(GETDATE(), -1)", write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"}, ) + + def test_variables(self): + # In TSQL @, # can be used as a prefix for variables/identifiers + expr = parse_one("@x", read="tsql") + self.assertIsInstance(expr, exp.Column) + self.assertIsInstance(expr.this, exp.Identifier) + + expr = parse_one("#x", read="tsql") + self.assertIsInstance(expr, exp.Column) + self.assertIsInstance(expr.this, exp.Identifier) + + def test_system_time(self): + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo'", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME AS OF 'foo'""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo' AS alias", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME AS OF 'foo' AS alias""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME FROM c TO d", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME FROM c TO d""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME BETWEEN c AND d", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME BETWEEN c AND d""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME CONTAINED IN (c, d)", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME CONTAINED IN (c, d)""", + }, + ) + self.validate_all( + "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME ALL AS alias", + write={ + "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME ALL AS alias""", + }, + ) -- cgit v1.2.3