summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/dialects/test_bigquery.py3
-rw-r--r--tests/dialects/test_clickhouse.py97
-rw-r--r--tests/dialects/test_dialect.py1
-rw-r--r--tests/dialects/test_duckdb.py1
-rw-r--r--tests/dialects/test_hive.py1
-rw-r--r--tests/dialects/test_mysql.py2
-rw-r--r--tests/dialects/test_postgres.py13
-rw-r--r--tests/dialects/test_presto.py19
-rw-r--r--tests/dialects/test_redshift.py41
-rw-r--r--tests/dialects/test_spark.py25
-rw-r--r--tests/dialects/test_teradata.py21
-rw-r--r--tests/fixtures/identity.sql8
-rw-r--r--tests/test_parser.py3
-rw-r--r--tests/test_serde.py2
14 files changed, 223 insertions, 14 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index 99d8a3c..05ded11 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -6,6 +6,8 @@ class TestBigQuery(Validator):
dialect = "bigquery"
def test_bigquery(self):
+ self.validate_identity("SELECT * FROM x-0.a")
+ self.validate_identity("SELECT * FROM pivot CROSS JOIN foo")
self.validate_identity("SAFE_CAST(x AS STRING)")
self.validate_identity("SELECT * FROM a-b-c.mydataset.mytable")
self.validate_identity("SELECT * FROM abc-def-ghi")
@@ -35,6 +37,7 @@ class TestBigQuery(Validator):
"CREATE TABLE IF NOT EXISTS foo AS SELECT * FROM bla EXCEPT DISTINCT (SELECT * FROM bar) LIMIT 0"
)
+ self.validate_all("SELECT 1 AS hash", write={"bigquery": "SELECT 1 AS `hash`"})
self.validate_all('x <> ""', write={"bigquery": "x <> ''"})
self.validate_all('x <> """"""', write={"bigquery": "x <> ''"})
self.validate_all("x <> ''''''", write={"bigquery": "x <> ''"})
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index b6a7765..f5372d9 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -39,8 +39,17 @@ class TestClickhouse(Validator):
self.validate_identity(
"CREATE TABLE test (id UInt8) ENGINE=AggregatingMergeTree() ORDER BY tuple()"
)
+ self.validate_identity(
+ "CREATE TABLE test ON CLUSTER default (id UInt8) ENGINE=AggregatingMergeTree() ORDER BY tuple()"
+ )
+ self.validate_identity(
+ "CREATE MATERIALIZED VIEW test_view ON CLUSTER cl1 (id UInt8) ENGINE=AggregatingMergeTree() ORDER BY tuple() AS SELECT * FROM test_data"
+ )
self.validate_all(
+ r"'Enum8(\'Sunday\' = 0)'", write={"clickhouse": "'Enum8(''Sunday'' = 0)'"}
+ )
+ self.validate_all(
"SELECT uniq(x) FROM (SELECT any(y) AS x FROM (SELECT 1 AS y))",
read={
"bigquery": "SELECT APPROX_COUNT_DISTINCT(x) FROM (SELECT ANY_VALUE(y) x FROM (SELECT 1 y))",
@@ -395,3 +404,91 @@ SET
},
pretty=True,
)
+ self.validate_all(
+ """
+ CREATE DICTIONARY discounts_dict (
+ advertiser_id UInt64,
+ discount_start_date Date,
+ discount_end_date Date,
+ amount Float64
+ )
+ PRIMARY KEY id
+ SOURCE(CLICKHOUSE(TABLE 'discounts'))
+ LIFETIME(MIN 1 MAX 1000)
+ LAYOUT(RANGE_HASHED(range_lookup_strategy 'max'))
+ RANGE(MIN discount_start_date MAX discount_end_date)
+ """,
+ write={
+ "clickhouse": """CREATE DICTIONARY discounts_dict (
+ advertiser_id UInt64,
+ discount_start_date DATE,
+ discount_end_date DATE,
+ amount Float64
+)
+PRIMARY KEY (id)
+SOURCE(CLICKHOUSE(
+ TABLE 'discounts'
+))
+LIFETIME(MIN 1 MAX 1000)
+LAYOUT(RANGE_HASHED(
+ range_lookup_strategy 'max'
+))
+RANGE(MIN discount_start_date MAX discount_end_date)""",
+ },
+ pretty=True,
+ )
+ self.validate_all(
+ """
+ CREATE DICTIONARY my_ip_trie_dictionary (
+ prefix String,
+ asn UInt32,
+ cca2 String DEFAULT '??'
+ )
+ PRIMARY KEY prefix
+ SOURCE(CLICKHOUSE(TABLE 'my_ip_addresses'))
+ LAYOUT(IP_TRIE)
+ LIFETIME(3600);
+ """,
+ write={
+ "clickhouse": """CREATE DICTIONARY my_ip_trie_dictionary (
+ prefix TEXT,
+ asn UInt32,
+ cca2 TEXT DEFAULT '??'
+)
+PRIMARY KEY (prefix)
+SOURCE(CLICKHOUSE(
+ TABLE 'my_ip_addresses'
+))
+LAYOUT(IP_TRIE())
+LIFETIME(MIN 0 MAX 3600)""",
+ },
+ pretty=True,
+ )
+ self.validate_all(
+ """
+ CREATE DICTIONARY polygons_test_dictionary
+ (
+ key Array(Array(Array(Tuple(Float64, Float64)))),
+ name String
+ )
+ PRIMARY KEY key
+ SOURCE(CLICKHOUSE(TABLE 'polygons_test_table'))
+ LAYOUT(POLYGON(STORE_POLYGON_KEY_COLUMN 1))
+ LIFETIME(0);
+ """,
+ write={
+ "clickhouse": """CREATE DICTIONARY polygons_test_dictionary (
+ key Array(Array(Array(Tuple(Float64, Float64)))),
+ name TEXT
+)
+PRIMARY KEY (key)
+SOURCE(CLICKHOUSE(
+ TABLE 'polygons_test_table'
+))
+LAYOUT(POLYGON(
+ STORE_POLYGON_KEY_COLUMN 1
+))
+LIFETIME(MIN 0 MAX 0)""",
+ },
+ pretty=True,
+ )
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index e144e81..7e20812 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -484,6 +484,7 @@ class TestDialect(Validator):
"bigquery": "CAST(x AS DATE)",
"duckdb": "CAST(x AS DATE)",
"hive": "TO_DATE(x)",
+ "postgres": "CAST(x AS DATE)",
"presto": "CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE)",
"snowflake": "CAST(x AS DATE)",
},
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index ce6b122..ee15d04 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -132,6 +132,7 @@ class TestDuckDB(Validator):
parse_one("a // b", read="duckdb").assert_is(exp.IntDiv).sql(dialect="duckdb"), "a // b"
)
+ self.validate_identity("SELECT * FROM foo ASOF LEFT JOIN bar ON a = b")
self.validate_identity("PIVOT Cities ON Year USING SUM(Population)")
self.validate_identity("PIVOT Cities ON Year USING FIRST(Population)")
self.validate_identity("PIVOT Cities ON Year USING SUM(Population) GROUP BY Country")
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
index 99b5602..f6cc224 100644
--- a/tests/dialects/test_hive.py
+++ b/tests/dialects/test_hive.py
@@ -412,6 +412,7 @@ class TestHive(Validator):
"SELECT 1_a AS a FROM test_table",
write={
"spark": "SELECT 1_a AS a FROM test_table",
+ "trino": 'SELECT "1_a" AS a FROM test_table',
},
)
self.validate_all(
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index a80153b..4fb6fa5 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -121,7 +121,7 @@ class TestMySQL(Validator):
)
def test_canonical_functions(self):
- self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)")
+ self.validate_identity("SELECT LEFT('str', 2)", "SELECT LEFT('str', 2)")
self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')")
self.validate_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')")
self.validate_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')")
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 1f288c6..972a8c8 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -7,6 +7,7 @@ class TestPostgres(Validator):
dialect = "postgres"
def test_ddl(self):
+ self.validate_identity("CREATE TABLE public.y (x TSTZRANGE NOT NULL)")
self.validate_identity("CREATE TABLE test (foo HSTORE)")
self.validate_identity("CREATE TABLE test (foo JSONB)")
self.validate_identity("CREATE TABLE test (foo VARCHAR(64)[])")
@@ -85,6 +86,18 @@ class TestPostgres(Validator):
)
def test_postgres(self):
+ self.validate_identity("CAST(x AS INT4RANGE)")
+ self.validate_identity("CAST(x AS INT4MULTIRANGE)")
+ self.validate_identity("CAST(x AS INT8RANGE)")
+ self.validate_identity("CAST(x AS INT8MULTIRANGE)")
+ self.validate_identity("CAST(x AS NUMRANGE)")
+ self.validate_identity("CAST(x AS NUMMULTIRANGE)")
+ self.validate_identity("CAST(x AS TSRANGE)")
+ self.validate_identity("CAST(x AS TSMULTIRANGE)")
+ self.validate_identity("CAST(x AS TSTZRANGE)")
+ self.validate_identity("CAST(x AS TSTZMULTIRANGE)")
+ self.validate_identity("CAST(x AS DATERANGE)")
+ self.validate_identity("CAST(x AS DATEMULTIRANGE)")
self.validate_identity(
"""LAST_VALUE("col1") OVER (ORDER BY "col2" RANGE BETWEEN INTERVAL '1 day' PRECEDING AND '1 month' FOLLOWING)"""
)
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index 1f5953c..e3d09ef 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -7,6 +7,18 @@ class TestPresto(Validator):
def test_cast(self):
self.validate_all(
+ "SELECT DATE_DIFF('week', CAST(SUBSTR(CAST('2009-01-01' AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST('2009-12-31' AS VARCHAR), 1, 10) AS DATE))",
+ read={"redshift": "SELECT DATEDIFF(week, '2009-01-01', '2009-12-31')"},
+ )
+ self.validate_all(
+ "SELECT DATE_ADD('month', 18, CAST(SUBSTR(CAST('2008-02-28' AS VARCHAR), 1, 10) AS DATE))",
+ read={"redshift": "SELECT DATEADD(month, 18, '2008-02-28')"},
+ )
+ self.validate_all(
+ "SELECT TRY_CAST('1970-01-01 00:00:00' AS TIMESTAMP)",
+ read={"postgres": "SELECT 'epoch'::TIMESTAMP"},
+ )
+ self.validate_all(
"FROM_BASE64(x)",
read={
"hive": "UNBASE64(x)",
@@ -434,10 +446,17 @@ class TestPresto(Validator):
self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
self.validate_identity("APPROX_PERCENTILE(a, b, c, d)")
+ self.validate_all("VALUES 1, 2, 3", write={"presto": "VALUES (1), (2), (3)"})
self.validate_all("INTERVAL '1 day'", write={"trino": "INTERVAL '1' day"})
self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' week"})
self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' WEEKS"})
self.validate_all(
+ "SELECT SUBSTRING(a, 1, 3), SUBSTRING(a, LENGTH(a) - (3 - 1))",
+ read={
+ "redshift": "SELECT LEFT(a, 3), RIGHT(a, 3)",
+ },
+ )
+ self.validate_all(
"WITH RECURSIVE t(n) AS (SELECT 1 AS n UNION ALL SELECT n + 1 AS n FROM t WHERE n < 4) SELECT SUM(n) FROM t",
read={
"postgres": "WITH RECURSIVE t AS (SELECT 1 AS n UNION ALL SELECT n + 1 AS n FROM t WHERE n < 4) SELECT SUM(n) FROM t",
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index 6707b7a..f4efe24 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -11,6 +11,16 @@ class TestRedshift(Validator):
self.validate_identity("$foo")
self.validate_all(
+ "SELECT STRTOL('abc', 16)",
+ read={
+ "trino": "SELECT FROM_BASE('abc', 16)",
+ },
+ write={
+ "redshift": "SELECT STRTOL('abc', 16)",
+ "trino": "SELECT FROM_BASE('abc', 16)",
+ },
+ )
+ self.validate_all(
"SELECT SNAPSHOT, type",
write={
"": "SELECT SNAPSHOT, type",
@@ -19,6 +29,35 @@ class TestRedshift(Validator):
)
self.validate_all(
+ "x is true",
+ write={
+ "redshift": "x IS TRUE",
+ "presto": "x",
+ },
+ )
+ self.validate_all(
+ "x is false",
+ write={
+ "redshift": "x IS FALSE",
+ "presto": "NOT x",
+ },
+ )
+ self.validate_all(
+ "x is not false",
+ write={
+ "redshift": "NOT x IS FALSE",
+ "presto": "NOT NOT x",
+ },
+ )
+ self.validate_all(
+ "LEN(x)",
+ write={
+ "redshift": "LENGTH(x)",
+ "presto": "LENGTH(x)",
+ },
+ )
+
+ self.validate_all(
"SELECT SYSDATE",
write={
"": "SELECT CURRENT_TIMESTAMP()",
@@ -141,7 +180,7 @@ class TestRedshift(Validator):
"DATEDIFF('day', a, b)",
write={
"redshift": "DATEDIFF(day, a, b)",
- "presto": "DATE_DIFF('day', a, b)",
+ "presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(a AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(b AS VARCHAR), 1, 10) AS DATE))",
},
)
self.validate_all(
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index bcfd984..7c8ca1b 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -5,6 +5,9 @@ class TestSpark(Validator):
dialect = "spark"
def test_ddl(self):
+ self.validate_identity("CREATE TABLE foo (col VARCHAR(50))")
+ self.validate_identity("CREATE TABLE foo (col STRUCT<struct_col_a: VARCHAR((50))>)")
+
self.validate_all(
"CREATE TABLE db.example_table (col_a struct<struct_col_a:int, struct_col_b:string>)",
write={
@@ -223,6 +226,20 @@ TBLPROPERTIES (
self.validate_identity("SPLIT(str, pattern, lim)")
self.validate_all(
+ "SELECT * FROM ((VALUES 1))", write={"spark": "SELECT * FROM (VALUES (1))"}
+ )
+ self.validate_all(
+ "SELECT CAST(STRUCT('fooo') AS STRUCT<a: VARCHAR(2)>)",
+ write={"spark": "SELECT CAST(STRUCT('fooo') AS STRUCT<a: STRING>)"},
+ )
+ self.validate_all(
+ "SELECT CAST(123456 AS VARCHAR(3))",
+ write={
+ "": "SELECT TRY_CAST(123456 AS TEXT)",
+ "spark": "SELECT CAST(123456 AS STRING)",
+ },
+ )
+ self.validate_all(
"SELECT piv.Q1 FROM (SELECT * FROM produce PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2'))) AS piv",
read={
"snowflake": "SELECT piv.Q1 FROM produce PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2')) piv",
@@ -368,10 +385,10 @@ TBLPROPERTIES (
self.validate_all(
"SELECT LEFT(x, 2), RIGHT(x, 2)",
write={
- "duckdb": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)",
- "presto": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)",
- "hive": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)",
- "spark": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)",
+ "duckdb": "SELECT LEFT(x, 2), RIGHT(x, 2)",
+ "presto": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - (2 - 1))",
+ "hive": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - (2 - 1))",
+ "spark": "SELECT LEFT(x, 2), RIGHT(x, 2)",
},
)
self.validate_all(
diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py
index 03eb02f..9f789d0 100644
--- a/tests/dialects/test_teradata.py
+++ b/tests/dialects/test_teradata.py
@@ -40,6 +40,27 @@ class TestTeradata(Validator):
self.validate_identity(
"CREATE TABLE a (b INT) PARTITION BY RANGE_N(b BETWEEN *, 1 AND * EACH b) INDEX (a)"
)
+ self.validate_identity(
+ "CREATE TABLE a, NO FALLBACK PROTECTION, NO LOG, NO JOURNAL, CHECKSUM=ON, NO MERGEBLOCKRATIO, BLOCKCOMPRESSION=ALWAYS (a INT)"
+ )
+ self.validate_identity(
+ "CREATE TABLE a, NO FALLBACK PROTECTION, NO LOG, NO JOURNAL, CHECKSUM=ON, NO MERGEBLOCKRATIO, BLOCKCOMPRESSION=ALWAYS (a INT)"
+ )
+ self.validate_identity(
+ "CREATE TABLE a, WITH JOURNAL TABLE=x.y.z, CHECKSUM=OFF, MERGEBLOCKRATIO=1, DATABLOCKSIZE=10 KBYTES (a INT)"
+ )
+ self.validate_identity(
+ "CREATE TABLE a, BEFORE JOURNAL, AFTER JOURNAL, FREESPACE=1, DEFAULT DATABLOCKSIZE, BLOCKCOMPRESSION=DEFAULT (a INT)"
+ )
+ self.validate_identity(
+ "CREATE TABLE a, DUAL JOURNAL, DUAL AFTER JOURNAL, MERGEBLOCKRATIO=1 PERCENT, DATABLOCKSIZE=10 KILOBYTES (a INT)"
+ )
+ self.validate_identity(
+ "CREATE TABLE a, DUAL BEFORE JOURNAL, LOCAL AFTER JOURNAL, MAXIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=AUTOTEMP(c1 INT) (a INT)"
+ )
+ self.validate_identity(
+ "CREATE VOLATILE MULTISET TABLE a, NOT LOCAL AFTER JOURNAL, FREESPACE=1 PERCENT, DATABLOCKSIZE=10 BYTES, WITH NO CONCURRENT ISOLATED LOADING FOR ALL (a INT)"
+ )
self.validate_all(
"""
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index 0a1e305..9fdddf1 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -160,6 +160,7 @@ CASE WHEN SUM(x) > 3 THEN 1 END OVER (PARTITION BY x)
SUM(ROW() OVER (PARTITION BY x))
SUM(ROW() OVER (PARTITION BY x + 1))
SUM(ROW() OVER (PARTITION BY x AND y))
+SUM(x) OVER (w ORDER BY y)
(ROW() OVER ())
CASE WHEN (x > 1) THEN 1 ELSE 0 END
CASE (1) WHEN 1 THEN 1 ELSE 0 END
@@ -570,14 +571,7 @@ CREATE TABLE foo (baz CHAR(4) CHARACTER SET LATIN UPPERCASE NOT CASESPECIFIC COM
CREATE TABLE foo (baz DATE FORMAT 'YYYY/MM/DD' TITLE 'title' INLINE LENGTH 1 COMPRESS ('a', 'b'))
CREATE TABLE t (title TEXT)
CREATE TABLE foo (baz INT, inline TEXT)
-CREATE TABLE a, FALLBACK, LOG, JOURNAL, CHECKSUM=DEFAULT, DEFAULT MERGEBLOCKRATIO, BLOCKCOMPRESSION=MANUAL (a INT)
-CREATE TABLE a, NO FALLBACK PROTECTION, NO LOG, NO JOURNAL, CHECKSUM=ON, NO MERGEBLOCKRATIO, BLOCKCOMPRESSION=ALWAYS (a INT)
-CREATE TABLE a, WITH JOURNAL TABLE=x.y.z, CHECKSUM=OFF, MERGEBLOCKRATIO=1, DATABLOCKSIZE=10 KBYTES (a INT)
-CREATE TABLE a, BEFORE JOURNAL, AFTER JOURNAL, FREESPACE=1, DEFAULT DATABLOCKSIZE, BLOCKCOMPRESSION=DEFAULT (a INT)
-CREATE TABLE a, DUAL JOURNAL, DUAL AFTER JOURNAL, MERGEBLOCKRATIO=1 PERCENT, DATABLOCKSIZE=10 KILOBYTES (a INT)
-CREATE TABLE a, DUAL BEFORE JOURNAL, LOCAL AFTER JOURNAL, MAXIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=AUTOTEMP(c1 INT) (a INT)
CREATE SET GLOBAL TEMPORARY TABLE a, NO BEFORE JOURNAL, NO AFTER JOURNAL, MINIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=NEVER (a INT)
-CREATE VOLATILE MULTISET TABLE a, NOT LOCAL AFTER JOURNAL, FREESPACE=1 PERCENT, DATABLOCKSIZE=10 BYTES, WITH NO CONCURRENT ISOLATED LOADING FOR ALL (a INT)
CREATE ALGORITHM=UNDEFINED DEFINER=foo@% SQL SECURITY DEFINER VIEW a AS (SELECT a FROM b)
CREATE TEMPORARY TABLE x AS SELECT a FROM d
CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 84ae0b5..897357f 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -521,3 +521,6 @@ class TestParser(unittest.TestCase):
self.assertEqual(
parse_one("create materialized table x").sql(), "CREATE MATERIALIZED TABLE x"
)
+
+ def test_parse_floats(self):
+ self.assertTrue(parse_one("1. ").is_number)
diff --git a/tests/test_serde.py b/tests/test_serde.py
index 6b5c989..1043fcf 100644
--- a/tests/test_serde.py
+++ b/tests/test_serde.py
@@ -27,7 +27,7 @@ class TestSerDe(unittest.TestCase):
self.assertEqual(before, after)
def test_type_annotations(self):
- before = annotate_types(parse_one("CAST('1' AS INT)"))
+ before = annotate_types(parse_one("CAST('1' AS STRUCT<x ARRAY<INT>>)"))
after = self.dump_load(before)
self.assertEqual(before.type, after.type)
self.assertEqual(before.this.type, after.this.type)