diff options
Diffstat (limited to 'tests/dialects/test_tsql.py')
-rw-r--r-- | tests/dialects/test_tsql.py | 67 |
1 files changed, 32 insertions, 35 deletions
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 92adf7a..7455650 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1,12 +1,18 @@ -from sqlglot import exp, parse, parse_one +from sqlglot import exp, parse from tests.dialects.test_dialect import Validator from sqlglot.errors import ParseError +from sqlglot.optimizer.annotate_types import annotate_types class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): + self.assertEqual( + annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"), + "SELECT 1 WHERE EXISTS(SELECT 1)", + ) + self.validate_identity("CREATE view a.b.c", "CREATE VIEW b.c") self.validate_identity("DROP view a.b.c", "DROP VIEW b.c") self.validate_identity("ROUND(x, 1, 0)") @@ -217,9 +223,9 @@ class TestTSQL(Validator): "CREATE TABLE [db].[tbl] ([a] INTEGER)", ) - projection = parse_one("SELECT a = 1", read="tsql").selects[0] - projection.assert_is(exp.Alias) - projection.args["alias"].assert_is(exp.Identifier) + self.validate_identity("SELECT a = 1", "SELECT 1 AS a").selects[0].assert_is( + exp.Alias + ).args["alias"].assert_is(exp.Identifier) self.validate_all( "IF OBJECT_ID('tempdb.dbo.#TempTableName', 'U') IS NOT NULL DROP TABLE #TempTableName", @@ -756,12 +762,9 @@ class TestTSQL(Validator): for view_attr in ("ENCRYPTION", "SCHEMABINDING", "VIEW_METADATA"): self.validate_identity(f"CREATE VIEW a.b WITH {view_attr} AS SELECT * FROM x") - expression = parse_one("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B", dialect="tsql") - self.assertIsInstance(expression, exp.AlterTable) - self.assertIsInstance(expression.args["actions"][0], exp.Drop) - self.assertEqual( - expression.sql(dialect="tsql"), "ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B" - ) + self.validate_identity("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B").assert_is( + exp.AlterTable + ).args["actions"][0].assert_is(exp.Drop) for clustered_keyword in ("CLUSTERED", "NONCLUSTERED"): self.validate_identity( @@ -795,10 +798,10 @@ class TestTSQL(Validator): ) self.validate_all( - "CREATE TABLE [#temptest] (name VARCHAR)", + "CREATE TABLE [#temptest] (name INTEGER)", read={ - "duckdb": "CREATE TEMPORARY TABLE 'temptest' (name VARCHAR)", - "tsql": "CREATE TABLE [#temptest] (name VARCHAR)", + "duckdb": "CREATE TEMPORARY TABLE 'temptest' (name INTEGER)", + "tsql": "CREATE TABLE [#temptest] (name INTEGER)", }, ) self.validate_all( @@ -1632,27 +1635,23 @@ WHERE ) def test_identifier_prefixes(self): - expr = parse_one("#x", read="tsql") - self.assertIsInstance(expr, exp.Column) - self.assertIsInstance(expr.this, exp.Identifier) - self.assertTrue(expr.this.args.get("temporary")) - self.assertEqual(expr.sql("tsql"), "#x") - - expr = parse_one("##x", read="tsql") - self.assertIsInstance(expr, exp.Column) - self.assertIsInstance(expr.this, exp.Identifier) - self.assertTrue(expr.this.args.get("global")) - self.assertEqual(expr.sql("tsql"), "##x") - - expr = parse_one("@x", read="tsql") - self.assertIsInstance(expr, exp.Parameter) - self.assertIsInstance(expr.this, exp.Var) - self.assertEqual(expr.sql("tsql"), "@x") + self.assertTrue( + self.validate_identity("#x") + .assert_is(exp.Column) + .this.assert_is(exp.Identifier) + .args.get("temporary") + ) + self.assertTrue( + self.validate_identity("##x") + .assert_is(exp.Column) + .this.assert_is(exp.Identifier) + .args.get("global") + ) - table = parse_one("select * from @x", read="tsql").args["from"].this - self.assertIsInstance(table, exp.Table) - self.assertIsInstance(table.this, exp.Parameter) - self.assertIsInstance(table.this.this, exp.Var) + self.validate_identity("@x").assert_is(exp.Parameter).this.assert_is(exp.Var) + self.validate_identity("SELECT * FROM @x").args["from"].this.assert_is( + exp.Table + ).this.assert_is(exp.Parameter).this.assert_is(exp.Var) self.validate_all( "SELECT @x", @@ -1663,8 +1662,6 @@ WHERE "tsql": "SELECT @x", }, ) - - def test_temp_table(self): self.validate_all( "SELECT * FROM #mytemptable", write={ |