summaryrefslogtreecommitdiffstats
path: root/tests/test_expressions.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/test_expressions.py107
1 files changed, 87 insertions, 20 deletions
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index 1395b24..62227cb 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -1,3 +1,4 @@
+import sys
import datetime
import math
import unittest
@@ -215,12 +216,14 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(exp.table_name(exp.to_table("a.b.c.d.e", dialect="bigquery")), "a.b.c.d.e")
self.assertEqual(exp.table_name(exp.to_table("'@foo'", dialect="snowflake")), "'@foo'")
self.assertEqual(exp.table_name(exp.to_table("@foo", dialect="snowflake")), "@foo")
+ self.assertEqual(exp.table_name(bq_dashed_table, identify=True), '"a-1"."b"."c"')
self.assertEqual(
exp.table_name(parse_one("foo.`{bar,er}`", read="databricks"), dialect="databricks"),
"foo.`{bar,er}`",
)
-
- self.assertEqual(exp.table_name(bq_dashed_table, identify=True), '"a-1"."b"."c"')
+ self.assertEqual(
+ exp.table_name(parse_one("/*c*/foo.bar", into=exp.Table), identify=True), '"foo"."bar"'
+ )
def test_table(self):
self.assertEqual(exp.table_("a", alias="b"), parse_one("select * from a b").find(exp.Table))
@@ -349,6 +352,7 @@ class TestExpressions(unittest.TestCase):
)
self.assertIsInstance(exp.func("instr", "x", "b", dialect="mysql"), exp.StrPosition)
+ self.assertIsInstance(exp.func("instr", "x", "b", dialect="sqlite"), exp.StrPosition)
self.assertIsInstance(exp.func("bla", 1, "foo"), exp.Anonymous)
self.assertIsInstance(
exp.func("cast", this=exp.Literal.number(5), to=exp.DataType.build("DOUBLE")),
@@ -431,6 +435,31 @@ class TestExpressions(unittest.TestCase):
table = expression.find(exp.Table)
self.assertEqual(table.alias_column_names, ["a", "b"])
+ def test_cast(self):
+ expression = parse_one("select cast(x as DATE)")
+ casts = list(expression.find_all(exp.Cast))
+ self.assertEqual(len(casts), 1)
+
+ cast = casts[0]
+ self.assertTrue(cast.to.is_type(exp.DataType.Type.DATE))
+
+ # check that already cast values arent re-cast if wrapped in a cast to the same type
+ recast = exp.cast(cast, to=exp.DataType.Type.DATE)
+ self.assertEqual(recast, cast)
+ self.assertEqual(recast.sql(), "CAST(x AS DATE)")
+
+ # however, recasting is fine if the types are different
+ recast = exp.cast(cast, to=exp.DataType.Type.VARCHAR)
+ self.assertNotEqual(recast, cast)
+ self.assertEqual(len(list(recast.find_all(exp.Cast))), 2)
+ self.assertEqual(recast.sql(), "CAST(CAST(x AS DATE) AS VARCHAR)")
+
+ # check that dialect is used when casting strings
+ self.assertEqual(
+ exp.cast("x", to="regtype", dialect="postgres").sql(), "CAST(x AS REGTYPE)"
+ )
+ self.assertEqual(exp.cast("`x`", to="date", dialect="hive").sql(), 'CAST("x" AS DATE)')
+
def test_ctes(self):
expression = parse_one("SELECT a FROM x")
self.assertEqual(expression.ctes, [])
@@ -648,6 +677,8 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(parse_one("STR_POSITION(a, 'test')"), exp.StrPosition)
self.assertIsInstance(parse_one("STR_TO_UNIX(a, 'format')"), exp.StrToUnix)
self.assertIsInstance(parse_one("STRUCT_EXTRACT(a, 'test')"), exp.StructExtract)
+ self.assertIsInstance(parse_one("SUBSTR('a', 1, 1)"), exp.Substring)
+ self.assertIsInstance(parse_one("SUBSTRING('a', 1, 1)"), exp.Substring)
self.assertIsInstance(parse_one("SUM(a)"), exp.Sum)
self.assertIsInstance(parse_one("SQRT(a)"), exp.Sqrt)
self.assertIsInstance(parse_one("STDDEV(a)"), exp.Stddev)
@@ -657,7 +688,10 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(parse_one("TIME_TO_TIME_STR(a)"), exp.Cast)
self.assertIsInstance(parse_one("TIME_TO_UNIX(a)"), exp.TimeToUnix)
self.assertIsInstance(parse_one("TIME_STR_TO_DATE(a)"), exp.TimeStrToDate)
- self.assertIsInstance(parse_one("TIME_STR_TO_TIME(a)"), exp.TimeStrToTime)
+ (self.assertIsInstance(parse_one("TIME_STR_TO_TIME(a)"), exp.TimeStrToTime),)
+ self.assertIsInstance(
+ parse_one("TIME_STR_TO_TIME(a, 'America/Los_Angeles')"), exp.TimeStrToTime
+ )
self.assertIsInstance(parse_one("TIME_STR_TO_UNIX(a)"), exp.TimeStrToUnix)
self.assertIsInstance(parse_one("TRIM(LEADING 'b' FROM 'bla')"), exp.Trim)
self.assertIsInstance(parse_one("TS_OR_DS_ADD(a, 1, 'day')"), exp.TsOrDsAdd)
@@ -791,6 +825,7 @@ class TestExpressions(unittest.TestCase):
def test_convert(self):
from collections import namedtuple
+ import pytz
PointTuple = namedtuple("Point", ["x", "y"])
@@ -809,11 +844,17 @@ class TestExpressions(unittest.TestCase):
({"x": None}, "MAP(ARRAY('x'), ARRAY(NULL))"),
(
datetime.datetime(2022, 10, 1, 1, 1, 1, 1),
- "TIME_STR_TO_TIME('2022-10-01 01:01:01.000001+00:00')",
+ "TIME_STR_TO_TIME('2022-10-01 01:01:01.000001')",
),
(
datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc),
- "TIME_STR_TO_TIME('2022-10-01 01:01:01+00:00')",
+ "TIME_STR_TO_TIME('2022-10-01 01:01:01+00:00', 'UTC')",
+ ),
+ (
+ pytz.timezone("America/Los_Angeles").localize(
+ datetime.datetime(2022, 10, 1, 1, 1, 1)
+ ),
+ "TIME_STR_TO_TIME('2022-10-01 01:01:01-07:00', 'America/Los_Angeles')",
),
(datetime.date(2022, 10, 1), "DATE_STR_TO_DATE('2022-10-01')"),
(math.nan, "NULL"),
@@ -829,6 +870,21 @@ class TestExpressions(unittest.TestCase):
"MAP_FROM_ARRAYS(ARRAY('test'), ARRAY('value'))",
)
+ @unittest.skipUnless(sys.version_info >= (3, 9), "zoneinfo only available from python 3.9+")
+ def test_convert_python39(self):
+ import zoneinfo
+
+ for value, expected in [
+ (
+ datetime.datetime(
+ 2022, 10, 1, 1, 1, 1, tzinfo=zoneinfo.ZoneInfo("America/Los_Angeles")
+ ),
+ "TIME_STR_TO_TIME('2022-10-01 01:01:01-07:00', 'America/Los_Angeles')",
+ )
+ ]:
+ with self.subTest(value):
+ self.assertEqual(exp.convert(value).sql(), expected)
+
def test_comment_alias(self):
sql = """
SELECT
@@ -881,15 +937,13 @@ FROM foo""",
def test_to_interval(self):
self.assertEqual(exp.to_interval("1day").sql(), "INTERVAL '1' DAY")
self.assertEqual(exp.to_interval(" 5 months").sql(), "INTERVAL '5' MONTHS")
- with self.assertRaises(ValueError):
- exp.to_interval("bla")
+ self.assertEqual(exp.to_interval("-2 day").sql(), "INTERVAL '-2' DAY")
self.assertEqual(exp.to_interval(exp.Literal.string("1day")).sql(), "INTERVAL '1' DAY")
+ self.assertEqual(exp.to_interval(exp.Literal.string("-2 day")).sql(), "INTERVAL '-2' DAY")
self.assertEqual(
exp.to_interval(exp.Literal.string(" 5 months")).sql(), "INTERVAL '5' MONTHS"
)
- with self.assertRaises(ValueError):
- exp.to_interval(exp.Literal.string("bla"))
def test_to_table(self):
table_only = exp.to_table("table_name")
@@ -984,7 +1038,6 @@ FROM foo""",
self.assertEqual(exp.DataType.build("GEOGRAPHY").sql(), "GEOGRAPHY")
self.assertEqual(exp.DataType.build("GEOMETRY").sql(), "GEOMETRY")
self.assertEqual(exp.DataType.build("STRUCT").sql(), "STRUCT")
- self.assertEqual(exp.DataType.build("NULLABLE").sql(), "NULLABLE")
self.assertEqual(exp.DataType.build("HLLSKETCH", dialect="redshift").sql(), "HLLSKETCH")
self.assertEqual(exp.DataType.build("HSTORE", dialect="postgres").sql(), "HSTORE")
self.assertEqual(exp.DataType.build("NULL").sql(), "NULL")
@@ -993,14 +1046,15 @@ FROM foo""",
self.assertEqual(exp.DataType.build("UNKNOWN", dialect="bigquery").sql(), "UNKNOWN")
self.assertEqual(exp.DataType.build("UNKNOWN", dialect="snowflake").sql(), "UNKNOWN")
self.assertEqual(exp.DataType.build("TIMESTAMP", dialect="bigquery").sql(), "TIMESTAMPTZ")
- self.assertEqual(
- exp.DataType.build("struct<x int>", dialect="spark").sql(), "STRUCT<x INT>"
- )
self.assertEqual(exp.DataType.build("USER-DEFINED").sql(), "USER-DEFINED")
-
self.assertEqual(exp.DataType.build("ARRAY<UNKNOWN>").sql(), "ARRAY<UNKNOWN>")
self.assertEqual(exp.DataType.build("ARRAY<NULL>").sql(), "ARRAY<NULL>")
self.assertEqual(exp.DataType.build("varchar(100) collate 'en-ci'").sql(), "VARCHAR(100)")
+ self.assertEqual(exp.DataType.build("int[3]").sql(dialect="duckdb"), "INT[3]")
+ self.assertEqual(exp.DataType.build("int[3][3]").sql(dialect="duckdb"), "INT[3][3]")
+ self.assertEqual(
+ exp.DataType.build("struct<x int>", dialect="spark").sql(), "STRUCT<x INT>"
+ )
with self.assertRaises(ParseError):
exp.DataType.build("varchar(")
@@ -1011,12 +1065,18 @@ FROM foo""",
"ALTER TABLE t1 RENAME TO t2",
)
- def test_is_negative(self):
- self.assertTrue(parse_one("-1").is_negative)
- self.assertTrue(parse_one("- 1.0").is_negative)
- self.assertTrue(exp.Literal.number("-1").is_negative)
- self.assertFalse(parse_one("1").is_negative)
- self.assertFalse(parse_one("x").is_negative)
+ def test_to_py(self):
+ self.assertEqual(parse_one("- -1").to_py(), 1)
+ self.assertIs(parse_one("TRUE").to_py(), True)
+ self.assertIs(parse_one("1").to_py(), 1)
+ self.assertIs(parse_one("'1'").to_py(), "1")
+ self.assertIs(parse_one("null").to_py(), None)
+
+ with self.assertRaises(ValueError):
+ parse_one("x").to_py()
+
+ def test_is_int(self):
+ self.assertTrue(parse_one("- -1").is_int)
def test_is_star(self):
assert parse_one("*").is_star
@@ -1099,6 +1159,10 @@ FROM foo""",
dtype = exp.DataType.build("a.b.c", udt=True)
assert dtype.is_type("a.b.c")
+ dtype = exp.DataType.build("Nullable(Int32)", dialect="clickhouse")
+ assert dtype.is_type("int")
+ assert not dtype.is_type("int", check_nullable=True)
+
with self.assertRaises(ParseError):
exp.DataType.build("foo")
@@ -1114,3 +1178,6 @@ FROM foo""",
AssertionError, "x is not <class 'sqlglot.expressions.Identifier'>\\."
):
parse_one("x").assert_is(exp.Identifier)
+
+ def test_parse_identifier(self):
+ self.assertEqual(exp.parse_identifier("a ' b"), exp.to_identifier("a ' b"))