diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dialects/test_bigquery.py | 25 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 7 | ||||
-rw-r--r-- | tests/dialects/test_sqlite.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 209 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 2 | ||||
-rw-r--r-- | tests/fixtures/optimizer/qualify_tables.sql | 74 | ||||
-rw-r--r-- | tests/test_expressions.py | 17 | ||||
-rw-r--r-- | tests/test_generator.py | 7 | ||||
-rw-r--r-- | tests/test_parser.py | 55 |
14 files changed, 391 insertions, 22 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index eac3cac..3939ba0 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -108,6 +108,27 @@ class TestBigQuery(Validator): self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"bigquery": "CAST(x AS TIMESTAMP)"}) self.validate_all("CAST(x AS RECORD)", write={"bigquery": "CAST(x AS STRUCT)"}) self.validate_all( + "MD5(x)", + write={ + "": "MD5_DIGEST(x)", + "bigquery": "MD5(x)", + "hive": "UNHEX(MD5(x))", + "spark": "UNHEX(MD5(x))", + }, + ) + self.validate_all( + "SELECT TO_HEX(MD5(some_string))", + read={ + "duckdb": "SELECT MD5(some_string)", + "spark": "SELECT MD5(some_string)", + }, + write={ + "": "SELECT MD5(some_string)", + "bigquery": "SELECT TO_HEX(MD5(some_string))", + "duckdb": "SELECT MD5(some_string)", + }, + ) + self.validate_all( "SELECT CAST('20201225' AS TIMESTAMP FORMAT 'YYYYMMDD' AT TIME ZONE 'America/New_York')", write={"bigquery": "SELECT PARSE_TIMESTAMP('%Y%m%d', '20201225', 'America/New_York')"}, ) @@ -263,7 +284,7 @@ class TestBigQuery(Validator): "duckdb": "CAST(a AS BIGINT)", "presto": "CAST(a AS BIGINT)", "hive": "CAST(a AS BIGINT)", - "spark": "CAST(a AS LONG)", + "spark": "CAST(a AS BIGINT)", }, ) self.validate_all( @@ -413,7 +434,7 @@ class TestBigQuery(Validator): "duckdb": "CREATE TABLE db.example_table (col_a STRUCT(struct_col_a BIGINT, struct_col_b STRUCT(nested_col_a TEXT, nested_col_b TEXT)))", "presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a BIGINT, struct_col_b ROW(nested_col_a VARCHAR, nested_col_b VARCHAR)))", "hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a BIGINT, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)", - "spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: LONG, struct_col_b: STRUCT<nested_col_a: STRING, nested_col_b: STRING>>)", + "spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: BIGINT, struct_col_b: STRUCT<nested_col_a: STRING, nested_col_b: STRING>>)", }, ) self.validate_all( diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 21efc6b..05738cf 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -218,7 +218,7 @@ class TestDialect(Validator): "presto": "CAST(a AS SMALLINT)", "redshift": "CAST(a AS SMALLINT)", "snowflake": "CAST(a AS SMALLINT)", - "spark": "CAST(a AS SHORT)", + "spark": "CAST(a AS SMALLINT)", "sqlite": "CAST(a AS INTEGER)", "starrocks": "CAST(a AS SMALLINT)", }, diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index cad1c15..336f47d 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -532,7 +532,7 @@ class TestDuckDB(Validator): "duckdb": "CAST(COL AS BIGINT[])", "presto": "CAST(COL AS ARRAY(BIGINT))", "hive": "CAST(COL AS ARRAY<BIGINT>)", - "spark": "CAST(COL AS ARRAY<LONG>)", + "spark": "CAST(COL AS ARRAY<BIGINT>)", "postgres": "CAST(COL AS BIGINT[])", "snowflake": "CAST(COL AS ARRAY)", }, diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index c9bcf16..0503f6a 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -73,7 +73,7 @@ class TestHive(Validator): "duckdb": "TRY_CAST(1 AS SMALLINT)", "presto": "TRY_CAST(1 AS SMALLINT)", "hive": "CAST(1 AS SMALLINT)", - "spark": "CAST(1 AS SHORT)", + "spark": "CAST(1 AS SMALLINT)", }, ) self.validate_all( @@ -82,7 +82,7 @@ class TestHive(Validator): "duckdb": "TRY_CAST(1 AS SMALLINT)", "presto": "TRY_CAST(1 AS SMALLINT)", "hive": "CAST(1 AS SMALLINT)", - "spark": "CAST(1 AS SHORT)", + "spark": "CAST(1 AS SMALLINT)", }, ) self.validate_all( @@ -91,7 +91,7 @@ class TestHive(Validator): "duckdb": "TRY_CAST(1 AS TINYINT)", "presto": "TRY_CAST(1 AS TINYINT)", "hive": "CAST(1 AS TINYINT)", - "spark": "CAST(1 AS BYTE)", + "spark": "CAST(1 AS TINYINT)", }, ) self.validate_all( @@ -100,7 +100,7 @@ class TestHive(Validator): "duckdb": "TRY_CAST(1 AS BIGINT)", "presto": "TRY_CAST(1 AS BIGINT)", "hive": "CAST(1 AS BIGINT)", - "spark": "CAST(1 AS LONG)", + "spark": "CAST(1 AS BIGINT)", }, ) self.validate_all( diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 052d4cc..605dfff 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -121,6 +121,7 @@ class TestPostgres(Validator): ) def test_postgres(self): + self.validate_identity("CAST(x AS MONEY)") self.validate_identity("CAST(x AS INT4RANGE)") self.validate_identity("CAST(x AS INT4MULTIRANGE)") self.validate_identity("CAST(x AS INT8RANGE)") diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 45a0cd9..ddfa9e8 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -65,7 +65,7 @@ class TestPresto(Validator): "bigquery": "CAST([1, 2] AS ARRAY<INT64>)", "duckdb": "CAST(LIST_VALUE(1, 2) AS BIGINT[])", "presto": "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))", - "spark": "CAST(ARRAY(1, 2) AS ARRAY<LONG>)", + "spark": "CAST(ARRAY(1, 2) AS ARRAY<BIGINT>)", "snowflake": "CAST([1, 2] AS ARRAY)", }, ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 25841c5..32be23e 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -233,6 +233,13 @@ TBLPROPERTIES ( self.validate_identity("SPLIT(str, pattern, lim)") self.validate_all( + "UNHEX(MD5(x))", + write={ + "bigquery": "FROM_HEX(TO_HEX(MD5(x)))", + "spark": "UNHEX(MD5(x))", + }, + ) + self.validate_all( "SELECT * FROM ((VALUES 1))", write={"spark": "SELECT * FROM (VALUES (1))"} ) self.validate_all( diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index 10da9b0..4cf0832 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -95,7 +95,7 @@ class TestSQLite(Validator): "SELECT CAST([a].[b] AS SMALLINT) FROM foo", write={ "sqlite": 'SELECT CAST("a"."b" AS INTEGER) FROM foo', - "spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo", + "spark": "SELECT CAST(`a`.`b` AS SMALLINT) FROM foo", }, ) self.validate_all( diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index ca6d70c..5426859 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -43,7 +43,7 @@ class TestTSQL(Validator): "SELECT CAST([a].[b] AS SMALLINT) FROM foo", write={ "tsql": 'SELECT CAST("a"."b" AS SMALLINT) FROM foo', - "spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo", + "spark": "SELECT CAST(`a`.`b` AS SMALLINT) FROM foo", }, ) self.validate_all( @@ -84,7 +84,7 @@ class TestTSQL(Validator): "SELECT CAST([a].[b] AS SMALLINT) FROM foo", write={ "tsql": 'SELECT CAST("a"."b" AS SMALLINT) FROM foo', - "spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo", + "spark": "SELECT CAST(`a`.`b` AS SMALLINT) FROM foo", }, ) self.validate_all( @@ -155,6 +155,211 @@ class TestTSQL(Validator): }, ) + def test__types_ints(self): + self.validate_all( + "CAST(X AS INT)", + write={ + "hive": "CAST(X AS INT)", + "spark2": "CAST(X AS INT)", + "spark": "CAST(X AS INT)", + "tsql": "CAST(X AS INTEGER)", + }, + ) + + self.validate_all( + "CAST(X AS BIGINT)", + write={ + "hive": "CAST(X AS BIGINT)", + "spark2": "CAST(X AS BIGINT)", + "spark": "CAST(X AS BIGINT)", + "tsql": "CAST(X AS BIGINT)", + }, + ) + + self.validate_all( + "CAST(X AS SMALLINT)", + write={ + "hive": "CAST(X AS SMALLINT)", + "spark2": "CAST(X AS SMALLINT)", + "spark": "CAST(X AS SMALLINT)", + "tsql": "CAST(X AS SMALLINT)", + }, + ) + + self.validate_all( + "CAST(X AS TINYINT)", + write={ + "hive": "CAST(X AS TINYINT)", + "spark2": "CAST(X AS TINYINT)", + "spark": "CAST(X AS TINYINT)", + "tsql": "CAST(X AS TINYINT)", + }, + ) + + def test_types_decimals(self): + self.validate_all( + "CAST(x as FLOAT)", + write={ + "spark": "CAST(x AS FLOAT)", + "tsql": "CAST(x AS FLOAT)", + }, + ) + + self.validate_all( + "CAST(x as DOUBLE)", + write={ + "spark": "CAST(x AS DOUBLE)", + "tsql": "CAST(x AS DOUBLE)", + }, + ) + + self.validate_all( + "CAST(x as DECIMAL(15, 4))", + write={ + "spark": "CAST(x AS DECIMAL(15, 4))", + "tsql": "CAST(x AS NUMERIC(15, 4))", + }, + ) + + self.validate_all( + "CAST(x as NUMERIC(13,3))", + write={ + "spark": "CAST(x AS DECIMAL(13, 3))", + "tsql": "CAST(x AS NUMERIC(13, 3))", + }, + ) + + self.validate_all( + "CAST(x as MONEY)", + write={ + "spark": "CAST(x AS DECIMAL(15, 4))", + "tsql": "CAST(x AS MONEY)", + }, + ) + + self.validate_all( + "CAST(x as SMALLMONEY)", + write={ + "spark": "CAST(x AS DECIMAL(6, 4))", + "tsql": "CAST(x AS SMALLMONEY)", + }, + ) + + self.validate_all( + "CAST(x as REAL)", + write={ + "spark": "CAST(x AS FLOAT)", + "tsql": "CAST(x AS FLOAT)", + }, + ) + + def test_types_string(self): + self.validate_all( + "CAST(x as CHAR(1))", + write={ + "spark": "CAST(x AS CHAR(1))", + "tsql": "CAST(x AS CHAR(1))", + }, + ) + + self.validate_all( + "CAST(x as VARCHAR(2))", + write={ + "spark": "CAST(x AS VARCHAR(2))", + "tsql": "CAST(x AS VARCHAR(2))", + }, + ) + + self.validate_all( + "CAST(x as NCHAR(1))", + write={ + "spark": "CAST(x AS CHAR(1))", + "tsql": "CAST(x AS CHAR(1))", + }, + ) + + self.validate_all( + "CAST(x as NVARCHAR(2))", + write={ + "spark": "CAST(x AS VARCHAR(2))", + "tsql": "CAST(x AS VARCHAR(2))", + }, + ) + + def test_types_date(self): + self.validate_all( + "CAST(x as DATE)", + write={ + "spark": "CAST(x AS DATE)", + "tsql": "CAST(x AS DATE)", + }, + ) + + self.validate_all( + "CAST(x as DATE)", + write={ + "spark": "CAST(x AS DATE)", + "tsql": "CAST(x AS DATE)", + }, + ) + + self.validate_all( + "CAST(x as TIME(4))", + write={ + "spark": "CAST(x AS TIMESTAMP)", + "tsql": "CAST(x AS TIMESTAMP(4))", + }, + ) + + self.validate_all( + "CAST(x as DATETIME2)", + write={ + "spark": "CAST(x AS TIMESTAMP)", + "tsql": "CAST(x AS DATETIME2)", + }, + ) + + self.validate_all( + "CAST(x as DATETIMEOFFSET)", + write={ + "spark": "CAST(x AS TIMESTAMP)", + "tsql": "CAST(x AS TIMESTAMPTZ)", + }, + ) + + self.validate_all( + "CAST(x as SMALLDATETIME)", + write={ + "spark": "CAST(x AS TIMESTAMP)", + "tsql": "CAST(x AS DATETIME2)", + }, + ) + + def test_types_bin(self): + self.validate_all( + "CAST(x as BIT)", + write={ + "spark": "CAST(x AS BOOLEAN)", + "tsql": "CAST(x AS BIT)", + }, + ) + + self.validate_all( + "CAST(x as UNIQUEIDENTIFIER)", + write={ + "spark": "CAST(x AS STRING)", + "tsql": "CAST(x AS UNIQUEIDENTIFIER)", + }, + ) + + self.validate_all( + "CAST(x as VARBINARY)", + write={ + "spark": "CAST(x AS BINARY)", + "tsql": "CAST(x AS VARBINARY)", + }, + ) + def test_udf(self): self.validate_identity( "CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar" diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 162d627..63631c4 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -423,7 +423,6 @@ SELECT 1 INTERSECT SELECT 2 SELECT 1 INTERSECT SELECT 2 SELECT 1 AS delete, 2 AS alter SELECT * FROM (x) -SELECT * FROM ((x)) SELECT * FROM ((SELECT 1)) SELECT * FROM (x CROSS JOIN foo LATERAL VIEW EXPLODE(y)) SELECT * FROM (SELECT 1) AS x @@ -838,3 +837,4 @@ SELECT * FROM schema.case SELECT * FROM current_date SELECT * FROM schema.current_date SELECT /*+ SOME_HINT(foo) */ 1 +SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1) diff --git a/tests/fixtures/optimizer/qualify_tables.sql b/tests/fixtures/optimizer/qualify_tables.sql index 24d1b65..d8ce2b0 100644 --- a/tests/fixtures/optimizer/qualify_tables.sql +++ b/tests/fixtures/optimizer/qualify_tables.sql @@ -1,43 +1,103 @@ +# title: single table SELECT 1 FROM z; SELECT 1 FROM c.db.z AS z; +# title: single table with db SELECT 1 FROM y.z; SELECT 1 FROM c.y.z AS z; +# title: single table with db, catalog SELECT 1 FROM x.y.z; SELECT 1 FROM x.y.z AS z; +# title: single table with db, catalog, alias SELECT 1 FROM x.y.z AS z; SELECT 1 FROM x.y.z AS z; +# title: cte can't be qualified WITH a AS (SELECT 1 FROM z) SELECT 1 FROM a; WITH a AS (SELECT 1 FROM c.db.z AS z) SELECT 1 FROM a; +# title: query that yields a single column as projection SELECT (SELECT y.c FROM y AS y) FROM x; SELECT (SELECT y.c FROM c.db.y AS y) FROM c.db.x AS x; +# title: pivoted table SELECT * FROM x PIVOT (SUM(a) FOR b IN ('a', 'b')); SELECT * FROM c.db.x AS x PIVOT(SUM(a) FOR b IN ('a', 'b')) AS _q_0; ----------------------------- --- Expand join constructs ----------------------------- +----------------------------------------------------------- +--- Unnest wrapped tables / joins, expand join constructs +----------------------------------------------------------- --- This is valid in Trino, so we treat the (tbl AS tbl) as a "join construct" per postgres' terminology. -SELECT * FROM (tbl AS tbl) AS _q_0; -SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0; +# title: wrapped table without alias +SELECT * FROM (tbl); +SELECT * FROM c.db.tbl AS tbl; + +# title: wrapped table with alias +SELECT * FROM (tbl AS tbl); +SELECT * FROM c.db.tbl AS tbl; + +# title: wrapped table with alias and multiple redundant parentheses +SELECT * FROM ((((tbl AS tbl)))); +SELECT * FROM c.db.tbl AS tbl; + +# title: chained wrapped joins without aliases (1) +SELECT * FROM ((a CROSS JOIN b) CROSS JOIN c); +SELECT * FROM c.db.a AS a CROSS JOIN c.db.b AS b CROSS JOIN c.db.c AS c; + +# title: chained wrapped joins without aliases (2) +SELECT * FROM (a CROSS JOIN (b CROSS JOIN c)); +SELECT * FROM c.db.a AS a CROSS JOIN c.db.b AS b CROSS JOIN c.db.c AS c; + +# title: chained wrapped joins without aliases (3) +SELECT * FROM ((a CROSS JOIN ((b CROSS JOIN c) CROSS JOIN d))); +SELECT * FROM c.db.a AS a CROSS JOIN c.db.b AS b CROSS JOIN c.db.c AS c CROSS JOIN c.db.d AS d; + +# title: chained wrapped joins without aliases (4) +SELECT * FROM ((a CROSS JOIN ((b CROSS JOIN c) CROSS JOIN (d CROSS JOIN e)))); +SELECT * FROM c.db.a AS a CROSS JOIN c.db.b AS b CROSS JOIN c.db.c AS c CROSS JOIN c.db.d AS d CROSS JOIN c.db.e AS e; + +# title: chained wrapped joins with aliases +SELECT * FROM ((a AS foo CROSS JOIN b AS bar) CROSS JOIN c AS baz); +SELECT * FROM c.db.a AS foo CROSS JOIN c.db.b AS bar CROSS JOIN c.db.c AS baz; + +# title: wrapped join with subquery without alias +SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1); +SELECT * FROM c.db.tbl1 AS tbl1 CROSS JOIN (SELECT * FROM c.db.tbl2 AS tbl2) AS t1; + +# title: wrapped join with subquery with alias, parentheses can't be omitted because of alias +SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1) AS t2; +SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 CROSS JOIN (SELECT * FROM c.db.tbl2 AS tbl2) AS t1) AS t2; + +# title: join construct as the right operand of a left join +SELECT * FROM a LEFT JOIN (b INNER JOIN c ON c.id = b.id) ON b.id = a.id; +SELECT * FROM c.db.a AS a LEFT JOIN c.db.b AS b ON b.id = a.id INNER JOIN c.db.c AS c ON c.id = b.id; + +# title: nested joins converted to canonical form +SELECT * FROM a LEFT JOIN b INNER JOIN c ON c.id = b.id ON b.id = a.id; +SELECT * FROM c.db.a AS a LEFT JOIN c.db.b AS b ON b.id = a.id INNER JOIN c.db.c AS c ON c.id = b.id; + +# title: parentheses can't be omitted because alias shadows inner table names +SELECT t.a FROM (tbl AS tbl) AS t; +SELECT t.a FROM (SELECT * FROM c.db.tbl AS tbl) AS t; +# title: outermost set of parentheses can't be omitted due to shadowing (1) SELECT * FROM ((tbl AS tbl)) AS _q_0; SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0; -SELECT * FROM (((tbl AS tbl))) AS _q_0; +# title: outermost set of parentheses can't be omitted due to shadowing (2) +SELECT * FROM ((((tbl AS tbl)))) AS _q_0; SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0; +# title: join construct with three tables in canonical form SELECT * FROM (tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2 JOIN tbl3 AS tbl3 ON id1 = id3) AS _q_0; SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN c.db.tbl2 AS tbl2 ON id1 = id2 JOIN c.db.tbl3 AS tbl3 ON id1 = id3) AS _q_0; +# title: join construct with three tables in canonical form and redundant set of parentheses SELECT * FROM ((tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2 JOIN tbl3 AS tbl3 ON id1 = id3)) AS _q_0; SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN c.db.tbl2 AS tbl2 ON id1 = id2 JOIN c.db.tbl3 AS tbl3 ON id1 = id3) AS _q_0; +# title: nested join construct in canonical form SELECT * FROM (tbl1 AS tbl1 JOIN (tbl2 AS tbl2 JOIN tbl3 AS tbl3 ON id2 = id3) AS _q_0 ON id1 = id3) AS _q_1; SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN (SELECT * FROM c.db.tbl2 AS tbl2 JOIN c.db.tbl3 AS tbl3 ON id2 = id3) AS _q_0 ON id1 = id3) AS _q_1; diff --git a/tests/test_expressions.py b/tests/test_expressions.py index f050c0b..277bec1 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -442,12 +442,15 @@ class TestExpressions(unittest.TestCase): expression.find(exp.Table).replace(parse_one("y")) self.assertEqual(expression.sql(), "SELECT c, b FROM y") - def test_pop(self): + def test_arg_deletion(self): + # Using the pop helper method expression = parse_one("SELECT a, b FROM x") expression.find(exp.Column).pop() self.assertEqual(expression.sql(), "SELECT b FROM x") + expression.find(exp.Column).pop() self.assertEqual(expression.sql(), "SELECT FROM x") + expression.pop() self.assertEqual(expression.sql(), "SELECT FROM x") @@ -455,6 +458,15 @@ class TestExpressions(unittest.TestCase): expression.find(exp.With).pop() self.assertEqual(expression.sql(), "SELECT * FROM x") + # Manually deleting by setting to None + expression = parse_one("SELECT * FROM foo JOIN bar") + self.assertEqual(len(expression.args.get("joins", [])), 1) + + expression.set("joins", None) + self.assertEqual(expression.sql(), "SELECT * FROM foo") + self.assertEqual(expression.args.get("joins", []), []) + self.assertIsNone(expression.args.get("joins")) + def test_walk(self): expression = parse_one("SELECT * FROM (SELECT * FROM x)") self.assertEqual(len(list(expression.walk())), 9) @@ -539,6 +551,9 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("ARRAY(time, foo)"), exp.Array) self.assertIsInstance(parse_one("STANDARD_HASH('hello', 'sha256')"), exp.StandardHash) self.assertIsInstance(parse_one("DATE(foo)"), exp.Date) + self.assertIsInstance(parse_one("HEX(foo)"), exp.Hex) + self.assertIsInstance(parse_one("TO_HEX(foo)", read="bigquery"), exp.Hex) + self.assertIsInstance(parse_one("TO_HEX(MD5(foo))", read="bigquery"), exp.MD5) def test_column(self): column = parse_one("a.b.c.d") diff --git a/tests/test_generator.py b/tests/test_generator.py index fce5c81..ec90646 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,6 +1,6 @@ import unittest -from sqlglot import parse_one +from sqlglot import exp, parse_one from sqlglot.expressions import Func from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer @@ -30,6 +30,11 @@ class TestGenerator(unittest.TestCase): expression = NewParser().parse(tokens)[0] self.assertEqual(expression.sql(), "SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x") + self.assertEqual( + exp.DateTrunc(this=exp.to_column("event_date"), unit=exp.var("MONTH")).sql(), + "DATE_TRUNC(MONTH, event_date)", + ) + def test_identify(self): assert parse_one("x").sql(identify=True) == '"x"' assert parse_one("x").sql(identify="always") == '"x"' diff --git a/tests/test_parser.py b/tests/test_parser.py index 2fa6a09..891dcef 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -580,3 +580,58 @@ class TestParser(unittest.TestCase): def test_parse_floats(self): self.assertTrue(parse_one("1. ").is_number) + + def test_parse_wrapped_tables(self): + expr = parse_one("select * from (table)") + self.assertIsInstance(expr.args["from"].this, exp.Table) + self.assertTrue(expr.args["from"].this.args["wrapped"]) + + expr = parse_one("select * from (((table)))") + self.assertIsInstance(expr.args["from"].this, exp.Table) + self.assertTrue(expr.args["from"].this.args["wrapped"]) + + self.assertEqual(expr.sql(), "SELECT * FROM (table)") + + expr = parse_one("select * from (tbl1 join tbl2)") + self.assertIsInstance(expr.args["from"].this, exp.Table) + self.assertTrue(expr.args["from"].this.args["wrapped"]) + self.assertEqual(len(expr.args["from"].this.args["joins"]), 1) + + expr = parse_one("select * from (tbl1 join tbl2) t") + self.assertIsInstance(expr.args["from"].this, exp.Subquery) + self.assertIsInstance(expr.args["from"].this.this, exp.Select) + self.assertEqual(expr.sql(), "SELECT * FROM (SELECT * FROM tbl1, tbl2) AS t") + + expr = parse_one("select * from (tbl as tbl) t") + self.assertEqual(expr.sql(), "SELECT * FROM (SELECT * FROM tbl AS tbl) AS t") + + expr = parse_one("select * from ((a cross join b) cross join c)") + self.assertIsInstance(expr.args["from"].this, exp.Table) + self.assertTrue(expr.args["from"].this.args["wrapped"]) + self.assertEqual(len(expr.args["from"].this.args["joins"]), 2) + self.assertEqual(expr.sql(), "SELECT * FROM (a CROSS JOIN b CROSS JOIN c)") + + expr = parse_one("select * from ((a cross join b) cross join c) t") + self.assertIsInstance(expr.args["from"].this, exp.Subquery) + self.assertEqual(len(expr.args["from"].this.this.args["joins"]), 2) + self.assertEqual( + expr.sql(), "SELECT * FROM (SELECT * FROM a CROSS JOIN b CROSS JOIN c) AS t" + ) + + expr = parse_one("select * from (a cross join (b cross join c))") + self.assertIsInstance(expr.args["from"].this, exp.Table) + self.assertTrue(expr.args["from"].this.args["wrapped"]) + self.assertEqual(len(expr.args["from"].this.args["joins"]), 1) + self.assertIsInstance(expr.args["from"].this.args["joins"][0].this, exp.Table) + self.assertTrue(expr.args["from"].this.args["joins"][0].this.args["wrapped"]) + self.assertEqual(expr.sql(), "SELECT * FROM (a CROSS JOIN (b CROSS JOIN c))") + + expr = parse_one("select * from ((a cross join ((b cross join c) cross join d)))") + self.assertEqual(expr.sql(), "SELECT * FROM (a CROSS JOIN (b CROSS JOIN c CROSS JOIN d))") + + expr = parse_one( + "select * from ((a cross join ((b cross join c) cross join (d cross join e))))" + ) + self.assertEqual( + expr.sql(), "SELECT * FROM (a CROSS JOIN (b CROSS JOIN c CROSS JOIN (d CROSS JOIN e)))" + ) |