summaryrefslogtreecommitdiffstats
path: root/tests/dialects/test_tsql.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dialects/test_tsql.py')
-rw-r--r--tests/dialects/test_tsql.py142
1 files changed, 142 insertions, 0 deletions
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index d2972ca..4224a1e 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -1,3 +1,4 @@
+from sqlglot import exp, parse, parse_one
from tests.dialects.test_dialect import Validator
@@ -5,6 +6,10 @@ class TestTSQL(Validator):
dialect = "tsql"
def test_tsql(self):
+ self.validate_identity("SELECT CASE WHEN a > 1 THEN b END")
+ self.validate_identity("END")
+ self.validate_identity("@x")
+ self.validate_identity("#x")
self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'")
self.validate_identity("PRINT @TestVariable")
self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar")
@@ -87,6 +92,95 @@ class TestTSQL(Validator):
},
)
+ def test_udf(self):
+ self.validate_identity(
+ "CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar"
+ )
+ self.validate_identity(
+ "CREATE PROC foo @ID INTEGER, @AGE INTEGER AS SELECT DB_NAME(@ID) AS ThatDB"
+ )
+ self.validate_identity("CREATE PROC foo AS SELECT BAR() AS baz")
+ self.validate_identity("CREATE PROCEDURE foo AS SELECT BAR() AS baz")
+ self.validate_identity("CREATE FUNCTION foo(@bar INTEGER) RETURNS TABLE AS RETURN SELECT 1")
+ self.validate_identity("CREATE FUNCTION dbo.ISOweek(@DATE DATETIME2) RETURNS INTEGER")
+
+ # The following two cases don't necessarily correspond to valid TSQL, but they are used to verify
+ # that the syntax RETURNS @return_variable TABLE <table_type_definition> ... is parsed correctly.
+ #
+ # See also "Transact-SQL Multi-Statement Table-Valued Function Syntax"
+ # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql?view=sql-server-ver16
+ self.validate_identity(
+ "CREATE FUNCTION foo(@bar INTEGER) RETURNS @foo TABLE (x INTEGER, y NUMERIC) AS RETURN SELECT 1"
+ )
+ self.validate_identity(
+ "CREATE FUNCTION foo() RETURNS @contacts TABLE (first_name VARCHAR(50), phone VARCHAR(25)) AS SELECT @fname, @phone"
+ )
+
+ self.validate_all(
+ """
+ CREATE FUNCTION udfProductInYear (
+ @model_year INT
+ )
+ RETURNS TABLE
+ AS
+ RETURN
+ SELECT
+ product_name,
+ model_year,
+ list_price
+ FROM
+ production.products
+ WHERE
+ model_year = @model_year
+ """,
+ write={
+ "tsql": """CREATE FUNCTION udfProductInYear(
+ @model_year INTEGER
+)
+RETURNS TABLE AS
+RETURN SELECT
+ product_name,
+ model_year,
+ list_price
+FROM production.products
+WHERE
+ model_year = @model_year""",
+ },
+ pretty=True,
+ )
+
+ sql = """
+ CREATE procedure [TRANSF].[SP_Merge_Sales_Real]
+ @Loadid INTEGER
+ ,@NumberOfRows INTEGER
+ AS
+ BEGIN
+ SET XACT_ABORT ON;
+
+ DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104);
+ DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104);
+ DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER);
+ DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER);
+
+ DECLARE @SalesAmountBefore float;
+ SELECT @SalesAmountBefore=SUM(SalesAmount) FROM TRANSF.[Pre_Merge_Sales_Real] S;
+ END
+ """
+
+ expected_sqls = [
+ 'CREATE PROCEDURE "TRANSF"."SP_Merge_Sales_Real" @Loadid INTEGER, @NumberOfRows INTEGER AS BEGIN SET XACT_ABORT ON',
+ "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)",
+ "DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104)",
+ "DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER)",
+ "DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER)",
+ "DECLARE @SalesAmountBefore float",
+ 'SELECT @SalesAmountBefore = SUM(SalesAmount) FROM TRANSF."Pre_Merge_Sales_Real" AS S',
+ "END",
+ ]
+
+ for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
+ self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
+
def test_charindex(self):
self.validate_all(
"CHARINDEX(x, y, 9)",
@@ -472,3 +566,51 @@ class TestTSQL(Validator):
"EOMONTH(GETDATE(), -1)",
write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"},
)
+
+ def test_variables(self):
+ # In TSQL @, # can be used as a prefix for variables/identifiers
+ expr = parse_one("@x", read="tsql")
+ self.assertIsInstance(expr, exp.Column)
+ self.assertIsInstance(expr.this, exp.Identifier)
+
+ expr = parse_one("#x", read="tsql")
+ self.assertIsInstance(expr, exp.Column)
+ self.assertIsInstance(expr.this, exp.Identifier)
+
+ def test_system_time(self):
+ self.validate_all(
+ "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo'",
+ write={
+ "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME AS OF 'foo'""",
+ },
+ )
+ self.validate_all(
+ "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo' AS alias",
+ write={
+ "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME AS OF 'foo' AS alias""",
+ },
+ )
+ self.validate_all(
+ "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME FROM c TO d",
+ write={
+ "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME FROM c TO d""",
+ },
+ )
+ self.validate_all(
+ "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME BETWEEN c AND d",
+ write={
+ "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME BETWEEN c AND d""",
+ },
+ )
+ self.validate_all(
+ "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME CONTAINED IN (c, d)",
+ write={
+ "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME CONTAINED IN (c, d)""",
+ },
+ )
+ self.validate_all(
+ "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME ALL AS alias",
+ write={
+ "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME ALL AS alias""",
+ },
+ )