summaryrefslogtreecommitdiffstats
path: root/tests/dialects/test_tsql.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/dialects/test_tsql.py67
1 files changed, 32 insertions, 35 deletions
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index 92adf7a..7455650 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -1,12 +1,18 @@
-from sqlglot import exp, parse, parse_one
+from sqlglot import exp, parse
from tests.dialects.test_dialect import Validator
from sqlglot.errors import ParseError
+from sqlglot.optimizer.annotate_types import annotate_types
class TestTSQL(Validator):
dialect = "tsql"
def test_tsql(self):
+ self.assertEqual(
+ annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"),
+ "SELECT 1 WHERE EXISTS(SELECT 1)",
+ )
+
self.validate_identity("CREATE view a.b.c", "CREATE VIEW b.c")
self.validate_identity("DROP view a.b.c", "DROP VIEW b.c")
self.validate_identity("ROUND(x, 1, 0)")
@@ -217,9 +223,9 @@ class TestTSQL(Validator):
"CREATE TABLE [db].[tbl] ([a] INTEGER)",
)
- projection = parse_one("SELECT a = 1", read="tsql").selects[0]
- projection.assert_is(exp.Alias)
- projection.args["alias"].assert_is(exp.Identifier)
+ self.validate_identity("SELECT a = 1", "SELECT 1 AS a").selects[0].assert_is(
+ exp.Alias
+ ).args["alias"].assert_is(exp.Identifier)
self.validate_all(
"IF OBJECT_ID('tempdb.dbo.#TempTableName', 'U') IS NOT NULL DROP TABLE #TempTableName",
@@ -756,12 +762,9 @@ class TestTSQL(Validator):
for view_attr in ("ENCRYPTION", "SCHEMABINDING", "VIEW_METADATA"):
self.validate_identity(f"CREATE VIEW a.b WITH {view_attr} AS SELECT * FROM x")
- expression = parse_one("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B", dialect="tsql")
- self.assertIsInstance(expression, exp.AlterTable)
- self.assertIsInstance(expression.args["actions"][0], exp.Drop)
- self.assertEqual(
- expression.sql(dialect="tsql"), "ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B"
- )
+ self.validate_identity("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B").assert_is(
+ exp.AlterTable
+ ).args["actions"][0].assert_is(exp.Drop)
for clustered_keyword in ("CLUSTERED", "NONCLUSTERED"):
self.validate_identity(
@@ -795,10 +798,10 @@ class TestTSQL(Validator):
)
self.validate_all(
- "CREATE TABLE [#temptest] (name VARCHAR)",
+ "CREATE TABLE [#temptest] (name INTEGER)",
read={
- "duckdb": "CREATE TEMPORARY TABLE 'temptest' (name VARCHAR)",
- "tsql": "CREATE TABLE [#temptest] (name VARCHAR)",
+ "duckdb": "CREATE TEMPORARY TABLE 'temptest' (name INTEGER)",
+ "tsql": "CREATE TABLE [#temptest] (name INTEGER)",
},
)
self.validate_all(
@@ -1632,27 +1635,23 @@ WHERE
)
def test_identifier_prefixes(self):
- expr = parse_one("#x", read="tsql")
- self.assertIsInstance(expr, exp.Column)
- self.assertIsInstance(expr.this, exp.Identifier)
- self.assertTrue(expr.this.args.get("temporary"))
- self.assertEqual(expr.sql("tsql"), "#x")
-
- expr = parse_one("##x", read="tsql")
- self.assertIsInstance(expr, exp.Column)
- self.assertIsInstance(expr.this, exp.Identifier)
- self.assertTrue(expr.this.args.get("global"))
- self.assertEqual(expr.sql("tsql"), "##x")
-
- expr = parse_one("@x", read="tsql")
- self.assertIsInstance(expr, exp.Parameter)
- self.assertIsInstance(expr.this, exp.Var)
- self.assertEqual(expr.sql("tsql"), "@x")
+ self.assertTrue(
+ self.validate_identity("#x")
+ .assert_is(exp.Column)
+ .this.assert_is(exp.Identifier)
+ .args.get("temporary")
+ )
+ self.assertTrue(
+ self.validate_identity("##x")
+ .assert_is(exp.Column)
+ .this.assert_is(exp.Identifier)
+ .args.get("global")
+ )
- table = parse_one("select * from @x", read="tsql").args["from"].this
- self.assertIsInstance(table, exp.Table)
- self.assertIsInstance(table.this, exp.Parameter)
- self.assertIsInstance(table.this.this, exp.Var)
+ self.validate_identity("@x").assert_is(exp.Parameter).this.assert_is(exp.Var)
+ self.validate_identity("SELECT * FROM @x").args["from"].this.assert_is(
+ exp.Table
+ ).this.assert_is(exp.Parameter).this.assert_is(exp.Var)
self.validate_all(
"SELECT @x",
@@ -1663,8 +1662,6 @@ WHERE
"tsql": "SELECT @x",
},
)
-
- def test_temp_table(self):
self.validate_all(
"SELECT * FROM #mytemptable",
write={