summaryrefslogtreecommitdiffstats
path: root/tests/dialects/test_tsql.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dialects/test_tsql.py')
-rw-r--r--tests/dialects/test_tsql.py30
1 files changed, 20 insertions, 10 deletions
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index 7cf9971..101d356 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -1,4 +1,5 @@
from sqlglot import exp, parse, parse_one
+from sqlglot.parser import logger as parser_logger
from tests.dialects.test_dialect import Validator
@@ -7,7 +8,7 @@ class TestTSQL(Validator):
def test_tsql(self):
self.validate_identity("ROUND(x, 1, 0)")
- self.validate_identity("EXEC MyProc @id=7, @name='Lochristi'")
+ self.validate_identity("EXEC MyProc @id=7, @name='Lochristi'", check_command_warning=True)
# https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN
# tsql allows .. which means use the default schema
self.validate_identity("SELECT * FROM a..b")
@@ -225,7 +226,7 @@ class TestTSQL(Validator):
"MERGE INTO mytable WITH (HOLDLOCK) AS T USING mytable_merge AS S "
"ON (T.user_id = S.user_id) WHEN NOT MATCHED THEN INSERT (c1, c2) VALUES (S.c1, S.c2)"
)
- self.validate_identity("UPDATE STATISTICS x")
+ self.validate_identity("UPDATE STATISTICS x", check_command_warning=True)
self.validate_identity("UPDATE x SET y = 1 OUTPUT x.a, x.b INTO @y FROM y")
self.validate_identity("UPDATE x SET y = 1 OUTPUT x.a, x.b FROM y")
self.validate_identity("INSERT INTO x (y) OUTPUT x.a, x.b INTO l SELECT * FROM z")
@@ -238,14 +239,16 @@ class TestTSQL(Validator):
self.validate_identity("END")
self.validate_identity("@x")
self.validate_identity("#x")
- self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'")
- self.validate_identity("PRINT @TestVariable")
+ self.validate_identity("PRINT @TestVariable", check_command_warning=True)
self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar")
self.validate_identity("INSERT INTO @TestTable VALUES (1, 'Value1', 12, 20)")
self.validate_identity("SELECT * FROM #foo")
self.validate_identity("SELECT * FROM ##foo")
self.validate_identity("SELECT a = 1", "SELECT 1 AS a")
self.validate_identity(
+ "DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'", check_command_warning=True
+ )
+ self.validate_identity(
"SELECT a = 1 UNION ALL SELECT a = b", "SELECT 1 AS a UNION ALL SELECT b AS a"
)
self.validate_identity(
@@ -789,7 +792,8 @@ class TestTSQL(Validator):
def test_udf(self):
self.validate_identity(
- "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)"
+ "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)",
+ check_command_warning=True,
)
self.validate_identity(
"CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar"
@@ -882,8 +886,9 @@ WHERE
"END",
]
- for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
- self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
+ with self.assertLogs(parser_logger) as cm:
+ for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
+ self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
sql = """
CREATE PROC [dbo].[transform_proc] AS
@@ -902,8 +907,9 @@ WHERE
"CREATE TABLE [target_schema].[target_table] (a INTEGER) WITH (DISTRIBUTION=REPLICATE, HEAP)",
]
- for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
- self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
+ with self.assertLogs(parser_logger) as cm:
+ for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
+ self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
def test_charindex(self):
self.validate_all(
@@ -932,7 +938,11 @@ WHERE
)
def test_len(self):
- self.validate_all("LEN(x)", read={"": "LENGTH(x)"}, write={"spark": "LENGTH(x)"})
+ self.validate_all(
+ "LEN(x)", read={"": "LENGTH(x)"}, write={"spark": "LENGTH(CAST(x AS STRING))"}
+ )
+ self.validate_all("LEN(1)", write={"tsql": "LEN(1)", "spark": "LENGTH(CAST(1 AS STRING))"})
+ self.validate_all("LEN('x')", write={"tsql": "LEN('x')", "spark": "LENGTH('x')"})
def test_replicate(self):
self.validate_all("REPLICATE('x', 2)", write={"spark": "REPEAT('x', 2)"})