diff options
Diffstat (limited to 'tests/test_expressions.py')
-rw-r--r-- | tests/test_expressions.py | 68 |
1 files changed, 61 insertions, 7 deletions
diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 37a9720..9bb00de 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -1,3 +1,4 @@ +import sys import datetime import math import unittest @@ -431,6 +432,31 @@ class TestExpressions(unittest.TestCase): table = expression.find(exp.Table) self.assertEqual(table.alias_column_names, ["a", "b"]) + def test_cast(self): + expression = parse_one("select cast(x as DATE)") + casts = list(expression.find_all(exp.Cast)) + self.assertEqual(len(casts), 1) + + cast = casts[0] + self.assertTrue(cast.to.is_type(exp.DataType.Type.DATE)) + + # check that already cast values arent re-cast if wrapped in a cast to the same type + recast = exp.cast(cast, to=exp.DataType.Type.DATE) + self.assertEqual(recast, cast) + self.assertEqual(recast.sql(), "CAST(x AS DATE)") + + # however, recasting is fine if the types are different + recast = exp.cast(cast, to=exp.DataType.Type.VARCHAR) + self.assertNotEqual(recast, cast) + self.assertEqual(len(list(recast.find_all(exp.Cast))), 2) + self.assertEqual(recast.sql(), "CAST(CAST(x AS DATE) AS VARCHAR)") + + # check that dialect is used when casting strings + self.assertEqual( + exp.cast("x", to="regtype", dialect="postgres").sql(), "CAST(x AS REGTYPE)" + ) + self.assertEqual(exp.cast("`x`", to="date", dialect="hive").sql(), 'CAST("x" AS DATE)') + def test_ctes(self): expression = parse_one("SELECT a FROM x") self.assertEqual(expression.ctes, []) @@ -657,7 +683,10 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("TIME_TO_TIME_STR(a)"), exp.Cast) self.assertIsInstance(parse_one("TIME_TO_UNIX(a)"), exp.TimeToUnix) self.assertIsInstance(parse_one("TIME_STR_TO_DATE(a)"), exp.TimeStrToDate) - self.assertIsInstance(parse_one("TIME_STR_TO_TIME(a)"), exp.TimeStrToTime) + (self.assertIsInstance(parse_one("TIME_STR_TO_TIME(a)"), exp.TimeStrToTime),) + self.assertIsInstance( + parse_one("TIME_STR_TO_TIME(a, 'America/Los_Angeles')"), exp.TimeStrToTime + ) self.assertIsInstance(parse_one("TIME_STR_TO_UNIX(a)"), exp.TimeStrToUnix) self.assertIsInstance(parse_one("TRIM(LEADING 'b' FROM 'bla')"), exp.Trim) self.assertIsInstance(parse_one("TS_OR_DS_ADD(a, 1, 'day')"), exp.TsOrDsAdd) @@ -791,6 +820,7 @@ class TestExpressions(unittest.TestCase): def test_convert(self): from collections import namedtuple + import pytz PointTuple = namedtuple("Point", ["x", "y"]) @@ -809,11 +839,17 @@ class TestExpressions(unittest.TestCase): ({"x": None}, "MAP(ARRAY('x'), ARRAY(NULL))"), ( datetime.datetime(2022, 10, 1, 1, 1, 1, 1), - "TIME_STR_TO_TIME('2022-10-01 01:01:01.000001+00:00')", + "TIME_STR_TO_TIME('2022-10-01 01:01:01.000001')", ), ( datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc), - "TIME_STR_TO_TIME('2022-10-01 01:01:01+00:00')", + "TIME_STR_TO_TIME('2022-10-01 01:01:01+00:00', 'UTC')", + ), + ( + pytz.timezone("America/Los_Angeles").localize( + datetime.datetime(2022, 10, 1, 1, 1, 1) + ), + "TIME_STR_TO_TIME('2022-10-01 01:01:01-07:00', 'America/Los_Angeles')", ), (datetime.date(2022, 10, 1), "DATE_STR_TO_DATE('2022-10-01')"), (math.nan, "NULL"), @@ -829,6 +865,21 @@ class TestExpressions(unittest.TestCase): "MAP_FROM_ARRAYS(ARRAY('test'), ARRAY('value'))", ) + @unittest.skipUnless(sys.version_info >= (3, 9), "zoneinfo only available from python 3.9+") + def test_convert_python39(self): + import zoneinfo + + for value, expected in [ + ( + datetime.datetime( + 2022, 10, 1, 1, 1, 1, tzinfo=zoneinfo.ZoneInfo("America/Los_Angeles") + ), + "TIME_STR_TO_TIME('2022-10-01 01:01:01-07:00', 'America/Los_Angeles')", + ) + ]: + with self.subTest(value): + self.assertEqual(exp.convert(value).sql(), expected) + def test_comment_alias(self): sql = """ SELECT @@ -993,16 +1044,15 @@ FROM foo""", self.assertEqual(exp.DataType.build("UNKNOWN", dialect="bigquery").sql(), "UNKNOWN") self.assertEqual(exp.DataType.build("UNKNOWN", dialect="snowflake").sql(), "UNKNOWN") self.assertEqual(exp.DataType.build("TIMESTAMP", dialect="bigquery").sql(), "TIMESTAMPTZ") - self.assertEqual( - exp.DataType.build("struct<x int>", dialect="spark").sql(), "STRUCT<x INT>" - ) self.assertEqual(exp.DataType.build("USER-DEFINED").sql(), "USER-DEFINED") - self.assertEqual(exp.DataType.build("ARRAY<UNKNOWN>").sql(), "ARRAY<UNKNOWN>") self.assertEqual(exp.DataType.build("ARRAY<NULL>").sql(), "ARRAY<NULL>") self.assertEqual(exp.DataType.build("varchar(100) collate 'en-ci'").sql(), "VARCHAR(100)") self.assertEqual(exp.DataType.build("int[3]").sql(dialect="duckdb"), "INT[3]") self.assertEqual(exp.DataType.build("int[3][3]").sql(dialect="duckdb"), "INT[3][3]") + self.assertEqual( + exp.DataType.build("struct<x int>", dialect="spark").sql(), "STRUCT<x INT>" + ) with self.assertRaises(ParseError): exp.DataType.build("varchar(") @@ -1107,6 +1157,10 @@ FROM foo""", dtype = exp.DataType.build("a.b.c", udt=True) assert dtype.is_type("a.b.c") + dtype = exp.DataType.build("Nullable(Int32)", dialect="clickhouse") + assert dtype.is_type("int") + assert not dtype.is_type("int", check_nullable=True) + with self.assertRaises(ParseError): exp.DataType.build("foo") |