diff options
Diffstat (limited to 'tests/test_expressions.py')
-rw-r--r-- | tests/test_expressions.py | 53 |
1 files changed, 14 insertions, 39 deletions
diff --git a/tests/test_expressions.py b/tests/test_expressions.py index eaef022..716e457 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -26,9 +26,7 @@ class TestExpressions(unittest.TestCase): parse_one("ROW() OVER(Partition by y)"), parse_one("ROW() OVER (partition BY y)"), ) - self.assertEqual( - parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)") - ) + self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)")) def test_find(self): expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y") @@ -87,9 +85,7 @@ class TestExpressions(unittest.TestCase): self.assertIsNone(column.find_ancestor(exp.Join)) def test_alias_or_name(self): - expression = parse_one( - "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" - ) + expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz") self.assertEqual( [e.alias_or_name for e in expression.expressions], ["a", "B", "e", "*", "zz", "z"], @@ -118,9 +114,7 @@ class TestExpressions(unittest.TestCase): ) def test_named_selects(self): - expression = parse_one( - "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" - ) + expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz") self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"]) expression = parse_one( @@ -196,15 +190,9 @@ class TestExpressions(unittest.TestCase): def test_sql(self): self.assertEqual(parse_one("x + y * 2").sql(), "x + y * 2") - self.assertEqual( - parse_one('select "x"').sql(dialect="hive", pretty=True), "SELECT\n `x`" - ) - self.assertEqual( - parse_one("X + y").sql(identify=True, normalize=True), '"x" + "y"' - ) - self.assertEqual( - parse_one("SUM(X)").sql(identify=True, normalize=True), 'SUM("x")' - ) + self.assertEqual(parse_one('select "x"').sql(dialect="hive", pretty=True), "SELECT\n `x`") + self.assertEqual(parse_one("X + y").sql(identify=True, normalize=True), '"x" + "y"') + self.assertEqual(parse_one("SUM(X)").sql(identify=True, normalize=True), 'SUM("x")') def test_transform_with_arguments(self): expression = parse_one("a") @@ -229,15 +217,11 @@ class TestExpressions(unittest.TestCase): return node actual_expression_1 = expression.transform(fun) - self.assertEqual( - actual_expression_1.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)" - ) + self.assertEqual(actual_expression_1.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)") self.assertIsNot(actual_expression_1, expression) actual_expression_2 = expression.transform(fun, copy=False) - self.assertEqual( - actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)" - ) + self.assertEqual(actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)") self.assertIs(actual_expression_2, expression) with self.assertRaises(ValueError): @@ -274,12 +258,8 @@ class TestExpressions(unittest.TestCase): expression = parse_one("SELECT * FROM (SELECT * FROM x)") self.assertEqual(len(list(expression.walk())), 9) self.assertEqual(len(list(expression.walk(bfs=False))), 9) - self.assertTrue( - all(isinstance(e, exp.Expression) for e, _, _ in expression.walk()) - ) - self.assertTrue( - all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)) - ) + self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk())) + self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))) def test_functions(self): self.assertIsInstance(parse_one("ABS(a)"), exp.Abs) @@ -303,9 +283,7 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("IF(a, b, c)"), exp.If) self.assertIsInstance(parse_one("INITCAP(a)"), exp.Initcap) self.assertIsInstance(parse_one("JSON_EXTRACT(a, '$.name')"), exp.JSONExtract) - self.assertIsInstance( - parse_one("JSON_EXTRACT_SCALAR(a, '$.name')"), exp.JSONExtractScalar - ) + self.assertIsInstance(parse_one("JSON_EXTRACT_SCALAR(a, '$.name')"), exp.JSONExtractScalar) self.assertIsInstance(parse_one("LEAST(a, b)"), exp.Least) self.assertIsInstance(parse_one("LN(a)"), exp.Ln) self.assertIsInstance(parse_one("LOG10(a)"), exp.Log10) @@ -334,6 +312,7 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("TIME_STR_TO_DATE(a)"), exp.TimeStrToDate) self.assertIsInstance(parse_one("TIME_STR_TO_TIME(a)"), exp.TimeStrToTime) self.assertIsInstance(parse_one("TIME_STR_TO_UNIX(a)"), exp.TimeStrToUnix) + self.assertIsInstance(parse_one("TRIM(LEADING 'b' FROM 'bla')"), exp.Trim) self.assertIsInstance(parse_one("TS_OR_DS_ADD(a, 1, 'day')"), exp.TsOrDsAdd) self.assertIsInstance(parse_one("TS_OR_DS_TO_DATE(a)"), exp.TsOrDsToDate) self.assertIsInstance(parse_one("TS_OR_DS_TO_DATE_STR(a)"), exp.Substring) @@ -404,12 +383,8 @@ class TestExpressions(unittest.TestCase): self.assertFalse(exp.to_identifier("x").quoted) def test_function_normalizer(self): - self.assertEqual( - parse_one("HELLO()").sql(normalize_functions="lower"), "hello()" - ) - self.assertEqual( - parse_one("hello()").sql(normalize_functions="upper"), "HELLO()" - ) + self.assertEqual(parse_one("HELLO()").sql(normalize_functions="lower"), "hello()") + self.assertEqual(parse_one("hello()").sql(normalize_functions="upper"), "HELLO()") self.assertEqual(parse_one("heLLO()").sql(normalize_functions=None), "heLLO()") self.assertEqual(parse_one("SUM(x)").sql(normalize_functions="lower"), "sum(x)") self.assertEqual(parse_one("sum(x)").sql(normalize_functions="upper"), "SUM(x)") |