diff options
Diffstat (limited to '')
-rw-r--r-- | tests/dialects/test_snowflake.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 24 | ||||
-rw-r--r-- | tests/test_build.py | 6 | ||||
-rw-r--r-- | tests/test_parser.py | 2 | ||||
-rw-r--r-- | tests/test_schema.py | 5 |
5 files changed, 34 insertions, 5 deletions
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index bca5aaa..e3d0cff 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -496,7 +496,7 @@ FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS f.value AS "Contact", f1.value['type'] AS "Type", f1.value['content'] AS "Details" -FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL FLATTEN(input => f.value['business']) f1""", +FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERAL FLATTEN(input => f.value['business']) AS f1""", }, pretty=True, ) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index afdd48a..e4c6e60 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -371,13 +371,19 @@ class TestTSQL(Validator): self.validate_all( "SELECT t.x, y.z FROM x CROSS APPLY tvfTest(t.x)y(z)", write={ - "spark": "SELECT t.x, y.z FROM x JOIN LATERAL TVFTEST(t.x) y AS z", + "spark": "SELECT t.x, y.z FROM x JOIN LATERAL TVFTEST(t.x) AS y(z)", }, ) self.validate_all( "SELECT t.x, y.z FROM x OUTER APPLY tvfTest(t.x)y(z)", write={ - "spark": "SELECT t.x, y.z FROM x LEFT JOIN LATERAL TVFTEST(t.x) y AS z", + "spark": "SELECT t.x, y.z FROM x LEFT JOIN LATERAL TVFTEST(t.x) AS y(z)", + }, + ) + self.validate_all( + "SELECT t.x, y.z FROM x OUTER APPLY a.b.tvfTest(t.x)y(z)", + write={ + "spark": "SELECT t.x, y.z FROM x LEFT JOIN LATERAL a.b.TVFTEST(t.x) AS y(z)", }, ) @@ -421,3 +427,17 @@ class TestTSQL(Validator): self.validate_all( "SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"} ) + + def test_string(self): + self.validate_all( + "SELECT N'test'", + write={"spark": "SELECT 'test'"}, + ) + self.validate_all( + "SELECT n'test'", + write={"spark": "SELECT 'test'"}, + ) + self.validate_all( + "SELECT '''test'''", + write={"spark": r"SELECT '\'test\''"}, + ) diff --git a/tests/test_build.py b/tests/test_build.py index 721c868..b014a3a 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -473,6 +473,12 @@ class TestBuild(unittest.TestCase): (lambda: exp.values([("1", 2)]), "VALUES ('1', 2)"), (lambda: exp.values([("1", 2)], "alias"), "(VALUES ('1', 2)) AS alias"), (lambda: exp.values([("1", 2), ("2", 3)]), "VALUES ('1', 2), ('2', 3)"), + ( + lambda: exp.values( + [("1", 2, None), ("2", 3, None)], "alias", ["col1", "col2", "col3"] + ), + "(VALUES ('1', 2, NULL), ('2', 3, NULL)) AS alias(col1, col2, col3)", + ), (lambda: exp.delete("y", where="x > 1"), "DELETE FROM y WHERE x > 1"), (lambda: exp.delete("y", where=exp.and_("x > 1")), "DELETE FROM y WHERE x > 1"), ]: diff --git a/tests/test_parser.py b/tests/test_parser.py index fa7b589..0be15e4 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -85,7 +85,7 @@ class TestParser(unittest.TestCase): self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1) self.assertEqual( parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(), - """SELECT * FROM x, z LATERAL VIEW EXPLODE(y) CROSS JOIN y""", + """SELECT * FROM x, z CROSS JOIN y LATERAL VIEW EXPLODE(y)""", ) def test_command(self): diff --git a/tests/test_schema.py b/tests/test_schema.py index f1e12a2..6c1ca9c 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,6 +1,6 @@ import unittest -from sqlglot import exp, to_table +from sqlglot import exp, parse_one, to_table from sqlglot.errors import SchemaError from sqlglot.schema import MappingSchema, ensure_schema @@ -181,3 +181,6 @@ class TestSchema(unittest.TestCase): schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d").this, exp.DataType.Type.VARCHAR, ) + + schema = MappingSchema({"foo": {"bar": parse_one("INT", into=exp.DataType)}}) + self.assertEqual(schema.get_column_type("foo", "bar").this, exp.DataType.Type.INT) |