summaryrefslogtreecommitdiffstats
path: root/tests/dialects
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dialects')
-rw-r--r--tests/dialects/test_bigquery.py15
-rw-r--r--tests/dialects/test_duckdb.py16
-rw-r--r--tests/dialects/test_mysql.py11
-rw-r--r--tests/dialects/test_oracle.py70
-rw-r--r--tests/dialects/test_postgres.py23
-rw-r--r--tests/dialects/test_presto.py2
-rw-r--r--tests/dialects/test_redshift.py9
-rw-r--r--tests/dialects/test_snowflake.py12
-rw-r--r--tests/dialects/test_spark.py11
-rw-r--r--tests/dialects/test_sqlite.py1
-rw-r--r--tests/dialects/test_tsql.py67
11 files changed, 190 insertions, 47 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index bfaf009..ae8ed16 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -20,6 +20,14 @@ class TestBigQuery(Validator):
maxDiff = None
def test_bigquery(self):
+ self.validate_all(
+ "EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))",
+ write={
+ "bigquery": "EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))",
+ "duckdb": "EXTRACT(HOUR FROM MAKE_TIMESTAMP(2008, 12, 25, 15, 30, 00))",
+ "snowflake": "DATE_PART(HOUR, TIMESTAMP_FROM_PARTS(2008, 12, 25, 15, 30, 00))",
+ },
+ )
self.validate_identity(
"""CREATE TEMPORARY FUNCTION FOO()
RETURNS STRING
@@ -619,9 +627,9 @@ LANGUAGE js AS
'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)',
write={
"bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)",
- "databricks": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
+ "databricks": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
"mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)",
- "spark": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
+ "spark": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
},
)
self.validate_all(
@@ -761,12 +769,15 @@ LANGUAGE js AS
"clickhouse": "SHA256(x)",
"presto": "SHA256(x)",
"trino": "SHA256(x)",
+ "postgres": "SHA256(x)",
},
write={
"bigquery": "SHA256(x)",
"spark2": "SHA2(x, 256)",
"clickhouse": "SHA256(x)",
+ "postgres": "SHA256(x)",
"presto": "SHA256(x)",
+ "redshift": "SHA2(x, 256)",
"trino": "SHA256(x)",
},
)
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index cd68ff9..2bde478 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -19,6 +19,13 @@ class TestDuckDB(Validator):
)
self.validate_all(
+ "SELECT straight_join",
+ write={
+ "duckdb": "SELECT straight_join",
+ "mysql": "SELECT `straight_join`",
+ },
+ )
+ self.validate_all(
"SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMP)",
read={
"duckdb": "SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMP)",
@@ -278,6 +285,7 @@ class TestDuckDB(Validator):
self.validate_identity("FROM tbl", "SELECT * FROM tbl")
self.validate_identity("x -> '$.family'")
self.validate_identity("CREATE TABLE color (name ENUM('RED', 'GREEN', 'BLUE'))")
+ self.validate_identity("SELECT * FROM foo WHERE bar > $baz AND bla = $bob")
self.validate_identity(
"SELECT * FROM x LEFT JOIN UNNEST(y)", "SELECT * FROM x LEFT JOIN UNNEST(y) ON TRUE"
)
@@ -1000,6 +1008,7 @@ class TestDuckDB(Validator):
self.validate_identity("CAST(x AS CHAR)", "CAST(x AS TEXT)")
self.validate_identity("CAST(x AS BPCHAR)", "CAST(x AS TEXT)")
self.validate_identity("CAST(x AS STRING)", "CAST(x AS TEXT)")
+ self.validate_identity("CAST(x AS VARCHAR)", "CAST(x AS TEXT)")
self.validate_identity("CAST(x AS INT1)", "CAST(x AS TINYINT)")
self.validate_identity("CAST(x AS FLOAT4)", "CAST(x AS REAL)")
self.validate_identity("CAST(x AS FLOAT)", "CAST(x AS REAL)")
@@ -1028,6 +1037,13 @@ class TestDuckDB(Validator):
)
self.validate_all(
+ "CAST(x AS VARCHAR(5))",
+ write={
+ "duckdb": "CAST(x AS TEXT)",
+ "postgres": "CAST(x AS TEXT)",
+ },
+ )
+ self.validate_all(
"CAST(x AS DECIMAL(38, 0))",
read={
"snowflake": "CAST(x AS NUMBER)",
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index fdb7e91..280ebbf 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -21,6 +21,9 @@ class TestMySQL(Validator):
self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))")
self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))")
self.validate_identity("ALTER TABLE t1 ADD COLUMN x INT, ALGORITHM=INPLACE, LOCK=EXCLUSIVE")
+ self.validate_identity("ALTER TABLE t ADD INDEX `i` (`c`)")
+ self.validate_identity("ALTER TABLE t ADD UNIQUE `i` (`c`)")
+ self.validate_identity("ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT")
self.validate_identity(
"CREATE TABLE `oauth_consumer` (`key` VARCHAR(32) NOT NULL, UNIQUE `OAUTH_CONSUMER_KEY` (`key`))"
)
@@ -61,6 +64,10 @@ class TestMySQL(Validator):
"CREATE OR REPLACE VIEW my_view AS SELECT column1 AS `boo`, column2 AS `foo` FROM my_table WHERE column3 = 'some_value' UNION SELECT q.* FROM fruits_table, JSON_TABLE(Fruits, '$[*]' COLUMNS(id VARCHAR(255) PATH '$.$id', value VARCHAR(255) PATH '$.value')) AS q",
)
self.validate_identity(
+ "ALTER TABLE t ADD KEY `i` (`c`)",
+ "ALTER TABLE t ADD INDEX `i` (`c`)",
+ )
+ self.validate_identity(
"CREATE TABLE `foo` (`id` char(36) NOT NULL DEFAULT (uuid()), PRIMARY KEY (`id`), UNIQUE KEY `id` (`id`))",
"CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))",
)
@@ -77,9 +84,6 @@ class TestMySQL(Validator):
"ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT",
)
self.validate_identity(
- "ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT",
- )
- self.validate_identity(
"CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP) DEFAULT CHARSET=utf8 ROW_FORMAT=DYNAMIC",
"CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC",
)
@@ -113,6 +117,7 @@ class TestMySQL(Validator):
)
def test_identity(self):
+ self.validate_identity("SELECT e.* FROM e STRAIGHT_JOIN p ON e.x = p.y")
self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1")
self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')")
self.validate_identity("SELECT @var1 := 1, @var2")
diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py
index 526b0b5..7cc4d72 100644
--- a/tests/dialects/test_oracle.py
+++ b/tests/dialects/test_oracle.py
@@ -1,5 +1,5 @@
-from sqlglot import exp
-from sqlglot.errors import UnsupportedError
+from sqlglot import exp, UnsupportedError
+from sqlglot.dialects.oracle import eliminate_join_marks
from tests.dialects.test_dialect import Validator
@@ -43,6 +43,7 @@ class TestOracle(Validator):
self.validate_identity("SELECT * FROM table_name SAMPLE (25) s")
self.validate_identity("SELECT COUNT(*) * 10 FROM orders SAMPLE (10) SEED (1)")
self.validate_identity("SELECT * FROM V$SESSION")
+ self.validate_identity("SELECT TO_DATE('January 15, 1989, 11:00 A.M.')")
self.validate_identity(
"SELECT last_name, employee_id, manager_id, LEVEL FROM employees START WITH employee_id = 100 CONNECT BY PRIOR employee_id = manager_id ORDER SIBLINGS BY last_name"
)
@@ -249,7 +250,8 @@ class TestOracle(Validator):
self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y")
self.validate_all(
- "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)", write={"": UnsupportedError}
+ "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
+ write={"": UnsupportedError},
)
self.validate_all(
"SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
@@ -413,3 +415,65 @@ WHERE
for query in (f"{body}{start}{connect}", f"{body}{connect}{start}"):
self.validate_identity(query, pretty, pretty=True)
+
+ def test_eliminate_join_marks(self):
+ test_sql = [
+ (
+ "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) > 5",
+ "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y > 5",
+ ),
+ (
+ "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) IS NULL",
+ "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y IS NULL",
+ ),
+ (
+ "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y IS NULL",
+ "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T2.y IS NULL",
+ ),
+ (
+ "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T1.Z > 4",
+ "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T1.Z > 4",
+ ),
+ (
+ "SELECT * FROM table1, table2 WHERE table1.column = table2.column(+)",
+ "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column",
+ ),
+ (
+ "SELECT * FROM table1, table2, table3, table4 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+) and table1.column = table4.column(+)",
+ "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column LEFT JOIN table4 ON table1.column = table4.column",
+ ),
+ (
+ "SELECT * FROM table1, table2, table3 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+)",
+ "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column",
+ ),
+ (
+ "SELECT table1.id, table2.cloumn1, table3.id FROM table1, table2, (SELECT tableInner1.id FROM tableInner1, tableInner2 WHERE tableInner1.id = tableInner2.id(+)) AS table3 WHERE table1.id = table2.id(+) and table1.id = table3.id(+)",
+ "SELECT table1.id, table2.cloumn1, table3.id FROM table1 LEFT JOIN table2 ON table1.id = table2.id LEFT JOIN (SELECT tableInner1.id FROM tableInner1 LEFT JOIN tableInner2 ON tableInner1.id = tableInner2.id) table3 ON table1.id = table3.id",
+ ),
+ # 2 join marks on one side of predicate
+ (
+ "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + table2.column2(+)",
+ "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + table2.column2",
+ ),
+ # join mark and expression
+ (
+ "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + 25",
+ "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + 25",
+ ),
+ ]
+
+ for original, expected in test_sql:
+ with self.subTest(original):
+ self.assertEqual(
+ eliminate_join_marks(self.parse_one(original)).sql(dialect=self.dialect),
+ expected,
+ )
+
+ def test_query_restrictions(self):
+ for restriction in ("READ ONLY", "CHECK OPTION"):
+ for constraint_name in (" CONSTRAINT name", ""):
+ with self.subTest(f"Restriction: {restriction}"):
+ self.validate_identity(f"SELECT * FROM tbl WITH {restriction}{constraint_name}")
+ self.validate_identity(
+ f"CREATE VIEW view AS SELECT * FROM tbl WITH {restriction}{constraint_name}"
+ )
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 74753be..071677d 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -8,6 +8,7 @@ class TestPostgres(Validator):
dialect = "postgres"
def test_postgres(self):
+ self.validate_identity("SHA384(x)")
self.validate_identity(
'CREATE TABLE x (a TEXT COLLATE "de_DE")', "CREATE TABLE x (a TEXT COLLATE de_DE)"
)
@@ -724,6 +725,28 @@ class TestPostgres(Validator):
self.validate_identity("cast(a as FLOAT8)", "CAST(a AS DOUBLE PRECISION)")
self.validate_identity("cast(a as FLOAT4)", "CAST(a AS REAL)")
+ self.validate_all(
+ "1 / DIV(4, 2)",
+ read={
+ "postgres": "1 / DIV(4, 2)",
+ },
+ write={
+ "sqlite": "1 / CAST(CAST(CAST(4 AS REAL) / 2 AS INTEGER) AS REAL)",
+ "duckdb": "1 / CAST(4 // 2 AS DECIMAL)",
+ "bigquery": "1 / CAST(DIV(4, 2) AS NUMERIC)",
+ },
+ )
+ self.validate_all(
+ "CAST(DIV(4, 2) AS DECIMAL(5, 3))",
+ read={
+ "duckdb": "CAST(4 // 2 AS DECIMAL(5, 3))",
+ },
+ write={
+ "duckdb": "CAST(CAST(4 // 2 AS DECIMAL) AS DECIMAL(5, 3))",
+ "postgres": "CAST(DIV(4, 2) AS DECIMAL(5, 3))",
+ },
+ )
+
def test_ddl(self):
# Checks that user-defined types are parsed into DataType instead of Identifier
self.parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is(
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index f1bbcc1..ebb270a 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -564,6 +564,7 @@ class TestPresto(Validator):
self.validate_all(
f"{prefix}'Hello winter \\2603 !'",
write={
+ "oracle": "U'Hello winter \\2603 !'",
"presto": "U&'Hello winter \\2603 !'",
"snowflake": "'Hello winter \\u2603 !'",
"spark": "'Hello winter \\u2603 !'",
@@ -572,6 +573,7 @@ class TestPresto(Validator):
self.validate_all(
f"{prefix}'Hello winter #2603 !' UESCAPE '#'",
write={
+ "oracle": "U'Hello winter \\2603 !'",
"presto": "U&'Hello winter #2603 !' UESCAPE '#'",
"snowflake": "'Hello winter \\u2603 !'",
"spark": "'Hello winter \\u2603 !'",
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index 844fe46..69793c7 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -281,6 +281,9 @@ class TestRedshift(Validator):
"redshift": "SELECT DATEADD(MONTH, 18, '2008-02-28')",
"snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMP))",
"tsql": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS DATETIME2))",
+ "spark": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')",
+ "spark2": "SELECT ADD_MONTHS('2008-02-28', 18)",
+ "databricks": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')",
},
)
self.validate_all(
@@ -585,3 +588,9 @@ FROM (
self.assertEqual(
ast.sql("redshift"), "SELECT * FROM x AS a, a.b AS c, c.d.e AS f, f.g.h.i.j.k AS l"
)
+
+ def test_join_markers(self):
+ self.validate_identity(
+ "select a.foo, b.bar, a.baz from a, b where a.baz = b.baz (+)",
+ "SELECT a.foo, b.bar, a.baz FROM a, b WHERE a.baz = b.baz (+)",
+ )
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 9d9371d..1286436 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -125,6 +125,10 @@ WHERE
"SELECT a:from::STRING, a:from || ' test' ",
"SELECT CAST(GET_PATH(a, 'from') AS TEXT), GET_PATH(a, 'from') || ' test'",
)
+ self.validate_identity(
+ "SELECT a:select",
+ "SELECT GET_PATH(a, 'select')",
+ )
self.validate_identity("x:from", "GET_PATH(x, 'from')")
self.validate_identity(
"value:values::string::int",
@@ -1196,16 +1200,16 @@ WHERE
for constraint_prefix in ("WITH ", ""):
with self.subTest(f"Constraint prefix: {constraint_prefix}"):
self.validate_identity(
- f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p)",
- "CREATE TABLE t (id INT MASKING POLICY p)",
+ f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p.q.r)",
+ "CREATE TABLE t (id INT MASKING POLICY p.q.r)",
)
self.validate_identity(
f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p USING (c1, c2, c3))",
"CREATE TABLE t (id INT MASKING POLICY p USING (c1, c2, c3))",
)
self.validate_identity(
- f"CREATE TABLE t (id INT {constraint_prefix}PROJECTION POLICY p)",
- "CREATE TABLE t (id INT PROJECTION POLICY p)",
+ f"CREATE TABLE t (id INT {constraint_prefix}PROJECTION POLICY p.q.r)",
+ "CREATE TABLE t (id INT PROJECTION POLICY p.q.r)",
)
self.validate_identity(
f"CREATE TABLE t (id INT {constraint_prefix}TAG (key1='value_1', key2='value_2'))",
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index ecc152f..bff91bf 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -563,6 +563,7 @@ TBLPROPERTIES (
"SELECT DATE_ADD(my_date_column, 1)",
write={
"spark": "SELECT DATE_ADD(my_date_column, 1)",
+ "spark2": "SELECT DATE_ADD(my_date_column, 1)",
"bigquery": "SELECT DATE_ADD(CAST(CAST(my_date_column AS DATETIME) AS DATE), INTERVAL 1 DAY)",
},
)
@@ -675,6 +676,16 @@ TBLPROPERTIES (
"spark": "SELECT ARRAY_SORT(x)",
},
)
+ self.validate_all(
+ "SELECT DATE_ADD(MONTH, 20, col)",
+ read={
+ "spark": "SELECT TIMESTAMPADD(MONTH, 20, col)",
+ },
+ write={
+ "spark": "SELECT DATE_ADD(MONTH, 20, col)",
+ "databricks": "SELECT DATE_ADD(MONTH, 20, col)",
+ },
+ )
def test_bool_or(self):
self.validate_all(
diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py
index f3cde0b..46bbadc 100644
--- a/tests/dialects/test_sqlite.py
+++ b/tests/dialects/test_sqlite.py
@@ -202,6 +202,7 @@ class TestSQLite(Validator):
"CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",
read={
"mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)",
+ "postgres": "CREATE TABLE z (a INT GENERATED BY DEFAULT AS IDENTITY NOT NULL UNIQUE PRIMARY KEY)",
},
write={
"sqlite": "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index 92adf7a..7455650 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -1,12 +1,18 @@
-from sqlglot import exp, parse, parse_one
+from sqlglot import exp, parse
from tests.dialects.test_dialect import Validator
from sqlglot.errors import ParseError
+from sqlglot.optimizer.annotate_types import annotate_types
class TestTSQL(Validator):
dialect = "tsql"
def test_tsql(self):
+ self.assertEqual(
+ annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"),
+ "SELECT 1 WHERE EXISTS(SELECT 1)",
+ )
+
self.validate_identity("CREATE view a.b.c", "CREATE VIEW b.c")
self.validate_identity("DROP view a.b.c", "DROP VIEW b.c")
self.validate_identity("ROUND(x, 1, 0)")
@@ -217,9 +223,9 @@ class TestTSQL(Validator):
"CREATE TABLE [db].[tbl] ([a] INTEGER)",
)
- projection = parse_one("SELECT a = 1", read="tsql").selects[0]
- projection.assert_is(exp.Alias)
- projection.args["alias"].assert_is(exp.Identifier)
+ self.validate_identity("SELECT a = 1", "SELECT 1 AS a").selects[0].assert_is(
+ exp.Alias
+ ).args["alias"].assert_is(exp.Identifier)
self.validate_all(
"IF OBJECT_ID('tempdb.dbo.#TempTableName', 'U') IS NOT NULL DROP TABLE #TempTableName",
@@ -756,12 +762,9 @@ class TestTSQL(Validator):
for view_attr in ("ENCRYPTION", "SCHEMABINDING", "VIEW_METADATA"):
self.validate_identity(f"CREATE VIEW a.b WITH {view_attr} AS SELECT * FROM x")
- expression = parse_one("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B", dialect="tsql")
- self.assertIsInstance(expression, exp.AlterTable)
- self.assertIsInstance(expression.args["actions"][0], exp.Drop)
- self.assertEqual(
- expression.sql(dialect="tsql"), "ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B"
- )
+ self.validate_identity("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B").assert_is(
+ exp.AlterTable
+ ).args["actions"][0].assert_is(exp.Drop)
for clustered_keyword in ("CLUSTERED", "NONCLUSTERED"):
self.validate_identity(
@@ -795,10 +798,10 @@ class TestTSQL(Validator):
)
self.validate_all(
- "CREATE TABLE [#temptest] (name VARCHAR)",
+ "CREATE TABLE [#temptest] (name INTEGER)",
read={
- "duckdb": "CREATE TEMPORARY TABLE 'temptest' (name VARCHAR)",
- "tsql": "CREATE TABLE [#temptest] (name VARCHAR)",
+ "duckdb": "CREATE TEMPORARY TABLE 'temptest' (name INTEGER)",
+ "tsql": "CREATE TABLE [#temptest] (name INTEGER)",
},
)
self.validate_all(
@@ -1632,27 +1635,23 @@ WHERE
)
def test_identifier_prefixes(self):
- expr = parse_one("#x", read="tsql")
- self.assertIsInstance(expr, exp.Column)
- self.assertIsInstance(expr.this, exp.Identifier)
- self.assertTrue(expr.this.args.get("temporary"))
- self.assertEqual(expr.sql("tsql"), "#x")
-
- expr = parse_one("##x", read="tsql")
- self.assertIsInstance(expr, exp.Column)
- self.assertIsInstance(expr.this, exp.Identifier)
- self.assertTrue(expr.this.args.get("global"))
- self.assertEqual(expr.sql("tsql"), "##x")
-
- expr = parse_one("@x", read="tsql")
- self.assertIsInstance(expr, exp.Parameter)
- self.assertIsInstance(expr.this, exp.Var)
- self.assertEqual(expr.sql("tsql"), "@x")
+ self.assertTrue(
+ self.validate_identity("#x")
+ .assert_is(exp.Column)
+ .this.assert_is(exp.Identifier)
+ .args.get("temporary")
+ )
+ self.assertTrue(
+ self.validate_identity("##x")
+ .assert_is(exp.Column)
+ .this.assert_is(exp.Identifier)
+ .args.get("global")
+ )
- table = parse_one("select * from @x", read="tsql").args["from"].this
- self.assertIsInstance(table, exp.Table)
- self.assertIsInstance(table.this, exp.Parameter)
- self.assertIsInstance(table.this.this, exp.Var)
+ self.validate_identity("@x").assert_is(exp.Parameter).this.assert_is(exp.Var)
+ self.validate_identity("SELECT * FROM @x").args["from"].this.assert_is(
+ exp.Table
+ ).this.assert_is(exp.Parameter).this.assert_is(exp.Var)
self.validate_all(
"SELECT @x",
@@ -1663,8 +1662,6 @@ WHERE
"tsql": "SELECT @x",
},
)
-
- def test_temp_table(self):
self.validate_all(
"SELECT * FROM #mytemptable",
write={