From 376de8b6892deca7dc5d83035c047f1e13eb67ea Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 31 Jan 2024 06:44:41 +0100 Subject: Merging upstream version 20.11.0. Signed-off-by: Daniel Baumann --- tests/dialects/test_tsql.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) (limited to 'tests/dialects/test_tsql.py') diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 7cf9971..101d356 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1,4 +1,5 @@ from sqlglot import exp, parse, parse_one +from sqlglot.parser import logger as parser_logger from tests.dialects.test_dialect import Validator @@ -7,7 +8,7 @@ class TestTSQL(Validator): def test_tsql(self): self.validate_identity("ROUND(x, 1, 0)") - self.validate_identity("EXEC MyProc @id=7, @name='Lochristi'") + self.validate_identity("EXEC MyProc @id=7, @name='Lochristi'", check_command_warning=True) # https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN # tsql allows .. which means use the default schema self.validate_identity("SELECT * FROM a..b") @@ -225,7 +226,7 @@ class TestTSQL(Validator): "MERGE INTO mytable WITH (HOLDLOCK) AS T USING mytable_merge AS S " "ON (T.user_id = S.user_id) WHEN NOT MATCHED THEN INSERT (c1, c2) VALUES (S.c1, S.c2)" ) - self.validate_identity("UPDATE STATISTICS x") + self.validate_identity("UPDATE STATISTICS x", check_command_warning=True) self.validate_identity("UPDATE x SET y = 1 OUTPUT x.a, x.b INTO @y FROM y") self.validate_identity("UPDATE x SET y = 1 OUTPUT x.a, x.b FROM y") self.validate_identity("INSERT INTO x (y) OUTPUT x.a, x.b INTO l SELECT * FROM z") @@ -238,13 +239,15 @@ class TestTSQL(Validator): self.validate_identity("END") self.validate_identity("@x") self.validate_identity("#x") - self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'") - self.validate_identity("PRINT @TestVariable") + self.validate_identity("PRINT @TestVariable", check_command_warning=True) self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar") self.validate_identity("INSERT INTO @TestTable VALUES (1, 'Value1', 12, 20)") self.validate_identity("SELECT * FROM #foo") self.validate_identity("SELECT * FROM ##foo") self.validate_identity("SELECT a = 1", "SELECT 1 AS a") + self.validate_identity( + "DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'", check_command_warning=True + ) self.validate_identity( "SELECT a = 1 UNION ALL SELECT a = b", "SELECT 1 AS a UNION ALL SELECT b AS a" ) @@ -789,7 +792,8 @@ class TestTSQL(Validator): def test_udf(self): self.validate_identity( - "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)" + "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)", + check_command_warning=True, ) self.validate_identity( "CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar" @@ -882,8 +886,9 @@ WHERE "END", ] - for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): - self.assertEqual(expr.sql(dialect="tsql"), expected_sql) + with self.assertLogs(parser_logger) as cm: + for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): + self.assertEqual(expr.sql(dialect="tsql"), expected_sql) sql = """ CREATE PROC [dbo].[transform_proc] AS @@ -902,8 +907,9 @@ WHERE "CREATE TABLE [target_schema].[target_table] (a INTEGER) WITH (DISTRIBUTION=REPLICATE, HEAP)", ] - for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): - self.assertEqual(expr.sql(dialect="tsql"), expected_sql) + with self.assertLogs(parser_logger) as cm: + for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): + self.assertEqual(expr.sql(dialect="tsql"), expected_sql) def test_charindex(self): self.validate_all( @@ -932,7 +938,11 @@ WHERE ) def test_len(self): - self.validate_all("LEN(x)", read={"": "LENGTH(x)"}, write={"spark": "LENGTH(x)"}) + self.validate_all( + "LEN(x)", read={"": "LENGTH(x)"}, write={"spark": "LENGTH(CAST(x AS STRING))"} + ) + self.validate_all("LEN(1)", write={"tsql": "LEN(1)", "spark": "LENGTH(CAST(1 AS STRING))"}) + self.validate_all("LEN('x')", write={"tsql": "LEN('x')", "spark": "LENGTH('x')"}) def test_replicate(self): self.validate_all("REPLICATE('x', 2)", write={"spark": "REPEAT('x', 2)"}) -- cgit v1.2.3