diff options
Diffstat (limited to 'tests/test_expressions.py')
-rw-r--r-- | tests/test_expressions.py | 61 |
1 files changed, 59 insertions, 2 deletions
diff --git a/tests/test_expressions.py b/tests/test_expressions.py index f68ced2..5d1f810 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -2,7 +2,7 @@ import datetime import math import unittest -from sqlglot import alias, exp, parse_one +from sqlglot import ParseError, alias, exp, parse_one class TestExpressions(unittest.TestCase): @@ -188,6 +188,7 @@ class TestExpressions(unittest.TestCase): def test_table(self): self.assertEqual(exp.table_("a", alias="b"), parse_one("select * from a b").find(exp.Table)) + self.assertEqual(exp.table_("a", "").sql(), "a") def test_replace_tables(self): self.assertEqual( @@ -666,7 +667,7 @@ class TestExpressions(unittest.TestCase): (True, "TRUE"), ((1, "2", None), "(1, '2', NULL)"), ([1, "2", None], "ARRAY(1, '2', NULL)"), - ({"x": None}, "MAP('x', NULL)"), + ({"x": None}, "MAP(ARRAY('x'), ARRAY(NULL))"), ( datetime.datetime(2022, 10, 1, 1, 1, 1, 1), "TIME_STR_TO_TIME('2022-10-01T01:01:01.000001+00:00')", @@ -681,6 +682,11 @@ class TestExpressions(unittest.TestCase): with self.subTest(value): self.assertEqual(exp.convert(value).sql(), expected) + self.assertEqual( + exp.convert({"test": "value"}).sql(dialect="spark"), + "MAP_FROM_ARRAYS(ARRAY('test'), ARRAY('value'))", + ) + def test_comment_alias(self): sql = """ SELECT @@ -841,6 +847,9 @@ FROM foo""", ) self.assertEqual(exp.DataType.build("USER-DEFINED").sql(), "USER-DEFINED") + self.assertEqual(exp.DataType.build("ARRAY<UNKNOWN>").sql(), "ARRAY<UNKNOWN>") + self.assertEqual(exp.DataType.build("ARRAY<NULL>").sql(), "ARRAY<NULL>") + def test_rename_table(self): self.assertEqual( exp.rename_table("t1", "t2").sql(), @@ -879,3 +888,51 @@ FROM foo""", ast.meta["some_other_meta_key"] = "some_other_meta_value" self.assertEqual(ast.meta.get("some_other_meta_key"), "some_other_meta_value") + + def test_unnest(self): + ast = parse_one("SELECT (((1)))") + self.assertIs(ast.selects[0].unnest(), ast.find(exp.Literal)) + + ast = parse_one("SELECT * FROM (((SELECT * FROM t)))") + self.assertIs(ast.args["from"].this.unnest(), list(ast.find_all(exp.Select))[1]) + + ast = parse_one("SELECT * FROM ((((SELECT * FROM t))) AS foo)") + second_subquery = ast.args["from"].this.this + innermost_subquery = list(ast.find_all(exp.Select))[1].parent + self.assertIs(second_subquery, innermost_subquery.unwrap()) + + def test_is_type(self): + ast = parse_one("CAST(x AS VARCHAR)") + assert ast.is_type("VARCHAR") + assert not ast.is_type("VARCHAR(5)") + assert not ast.is_type("FLOAT") + + ast = parse_one("CAST(x AS VARCHAR(5))") + assert ast.is_type("VARCHAR") + assert ast.is_type("VARCHAR(5)") + assert not ast.is_type("VARCHAR(4)") + assert not ast.is_type("FLOAT") + + ast = parse_one("CAST(x AS ARRAY<INT>)") + assert ast.is_type("ARRAY") + assert ast.is_type("ARRAY<INT>") + assert not ast.is_type("ARRAY<FLOAT>") + assert not ast.is_type("INT") + + ast = parse_one("CAST(x AS ARRAY)") + assert ast.is_type("ARRAY") + assert not ast.is_type("ARRAY<INT>") + assert not ast.is_type("ARRAY<FLOAT>") + assert not ast.is_type("INT") + + ast = parse_one("CAST(x AS STRUCT<a INT, b FLOAT>)") + assert ast.is_type("STRUCT") + assert ast.is_type("STRUCT<a INT, b FLOAT>") + assert not ast.is_type("STRUCT<a VARCHAR, b INT>") + + dtype = exp.DataType.build("foo", udt=True) + assert dtype.is_type("foo") + assert not dtype.is_type("bar") + + with self.assertRaises(ParseError): + exp.DataType.build("foo") |