From a4bd8fff8aada95286f9b21ce5e1aaa852298625 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 23 Sep 2022 19:07:16 +0200 Subject: Merging upstream version 6.2.0. Signed-off-by: Daniel Baumann --- tests/dialects/test_dialect.py | 5 +++++ tests/dialects/test_duckdb.py | 3 +++ tests/dialects/test_hive.py | 2 +- tests/dialects/test_oracle.py | 6 ++++++ tests/dialects/test_presto.py | 4 ++-- tests/dialects/test_snowflake.py | 45 ++++++++++++++++++++++++++++++++++++++++ tests/dialects/test_spark.py | 12 ++--------- tests/dialects/test_tsql.py | 26 +++++++++++++++++++++++ tests/test_expressions.py | 32 +++++++++++++++++++++++++--- tests/test_parser.py | 3 +++ 10 files changed, 122 insertions(+), 16 deletions(-) create mode 100644 tests/dialects/test_oracle.py create mode 100644 tests/dialects/test_tsql.py (limited to 'tests') diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 6b7bfd3..4e0a3c6 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -228,6 +228,7 @@ class TestDialect(Validator): "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)", "presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')", + "redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')", "spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')", }, ) @@ -237,6 +238,7 @@ class TestDialect(Validator): "duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')", "hive": "CAST('2020-01-01' AS TIMESTAMP)", "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')", + "redshift": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')", "spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')", }, ) @@ -246,6 +248,7 @@ class TestDialect(Validator): "duckdb": "STRPTIME(x, '%y')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)", "presto": "DATE_PARSE(x, '%y')", + "redshift": "TO_TIMESTAMP(x, 'YY')", "spark": "TO_TIMESTAMP(x, 'yy')", }, ) @@ -287,6 +290,7 @@ class TestDialect(Validator): "duckdb": "STRFTIME(x, '%Y-%m-%d')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')", "presto": "DATE_FORMAT(x, '%Y-%m-%d')", + "redshift": "TO_CHAR(x, 'YYYY-MM-DD')", }, ) self.validate_all( @@ -295,6 +299,7 @@ class TestDialect(Validator): "duckdb": "CAST(x AS TEXT)", "hive": "CAST(x AS STRING)", "presto": "CAST(x AS VARCHAR)", + "redshift": "CAST(x AS TEXT)", }, ) self.validate_all( diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 501301f..f52decb 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -66,6 +66,9 @@ class TestDuckDB(Validator): def test_duckdb(self): self.validate_all( "LIST_VALUE(0, 1, 2)", + read={ + "spark": "ARRAY(0, 1, 2)", + }, write={ "bigquery": "[0, 1, 2]", "duckdb": "LIST_VALUE(0, 1, 2)", diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 55086e3..a9b5168 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -131,7 +131,7 @@ class TestHive(Validator): write={ "presto": "CREATE TABLE test WITH (FORMAT = 'parquet', x = '1', Z = '2') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", - "spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", + "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", }, ) self.validate_all( diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py new file mode 100644 index 0000000..1fadb84 --- /dev/null +++ b/tests/dialects/test_oracle.py @@ -0,0 +1,6 @@ +from tests.dialects.test_dialect import Validator + + +class TestOracle(Validator): + def test_oracle(self): + self.validate_identity("SELECT * FROM V$SESSION") diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index eb9aa5c..96c299d 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -173,7 +173,7 @@ class TestPresto(Validator): write={ "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", - "spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", + "spark": "CREATE TABLE test USING PARQUET AS SELECT 1", }, ) self.validate_all( @@ -181,7 +181,7 @@ class TestPresto(Validator): write={ "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1", - "spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1", + "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1", }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 2eeff52..165f8e2 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -175,3 +175,48 @@ class TestSnowflake(Validator): "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" }, ) + + def test_timestamps(self): + self.validate_all( + "SELECT CAST(a AS TIMESTAMP)", + write={ + "snowflake": "SELECT CAST(a AS TIMESTAMPNTZ)", + }, + ) + self.validate_all( + "SELECT a::TIMESTAMP_LTZ(9)", + write={ + "snowflake": "SELECT CAST(a AS TIMESTAMPLTZ(9))", + }, + ) + self.validate_all( + "SELECT a::TIMESTAMPLTZ", + write={ + "snowflake": "SELECT CAST(a AS TIMESTAMPLTZ)", + }, + ) + self.validate_all( + "SELECT a::TIMESTAMP WITH LOCAL TIME ZONE", + write={ + "snowflake": "SELECT CAST(a AS TIMESTAMPLTZ)", + }, + ) + self.validate_identity("SELECT EXTRACT(month FROM a)") + self.validate_all( + "SELECT EXTRACT('month', a)", + write={ + "snowflake": "SELECT EXTRACT('month' FROM a)", + }, + ) + self.validate_all( + "SELECT DATE_PART('month', a)", + write={ + "snowflake": "SELECT EXTRACT('month' FROM a)", + }, + ) + self.validate_all( + "SELECT DATE_PART(month FROM a::DATETIME)", + write={ + "snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))", + }, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 8794fed..22f6947 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -44,15 +44,7 @@ class TestSpark(Validator): write={ "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", - "spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", - }, - ) - self.validate_all( - "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1", - write={ - "presto": "CREATE TABLE test WITH (TABLE_FORMAT = 'ICEBERG', FORMAT = 'PARQUET') AS SELECT 1", - "hive": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1", - "spark": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1", + "spark": "CREATE TABLE test USING PARQUET AS SELECT 1", }, ) self.validate_all( @@ -86,7 +78,7 @@ COMMENT 'Test comment: blah' PARTITIONED BY ( date STRING ) -STORED AS ICEBERG +USING ICEBERG TBLPROPERTIES ( 'x' = '1' )""", diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py new file mode 100644 index 0000000..0619eaa --- /dev/null +++ b/tests/dialects/test_tsql.py @@ -0,0 +1,26 @@ +from tests.dialects.test_dialect import Validator + + +class TestTSQL(Validator): + dialect = "tsql" + + def test_tsql(self): + self.validate_identity('SELECT "x"."y" FROM foo') + + self.validate_all( + "SELECT CAST([a].[b] AS SMALLINT) FROM foo", + write={ + "tsql": 'SELECT CAST("a"."b" AS SMALLINT) FROM foo', + "spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo", + }, + ) + + def test_types(self): + self.validate_identity("CAST(x AS XML)") + self.validate_identity("CAST(x AS UNIQUEIDENTIFIER)") + self.validate_identity("CAST(x AS MONEY)") + self.validate_identity("CAST(x AS SMALLMONEY)") + self.validate_identity("CAST(x AS ROWVERSION)") + self.validate_identity("CAST(x AS IMAGE)") + self.validate_identity("CAST(x AS SQL_VARIANT)") + self.validate_identity("CAST(x AS BIT)") diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 716e457..59d584c 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -224,9 +224,6 @@ class TestExpressions(unittest.TestCase): self.assertEqual(actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)") self.assertIs(actual_expression_2, expression) - with self.assertRaises(ValueError): - parse_one("a").transform(lambda n: None) - def test_transform_no_infinite_recursion(self): expression = parse_one("a") @@ -247,6 +244,35 @@ class TestExpressions(unittest.TestCase): self.assertEqual(expression.transform(fun).sql(), "SELECT a, b FROM x") + def test_transform_node_removal(self): + expression = parse_one("SELECT a, b FROM x") + + def remove_column_b(node): + if isinstance(node, exp.Column) and node.name == "b": + return None + return node + + self.assertEqual(expression.transform(remove_column_b).sql(), "SELECT a FROM x") + self.assertEqual(expression.transform(lambda _: None), None) + + expression = parse_one("CAST(x AS FLOAT)") + + def remove_non_list_arg(node): + if isinstance(node, exp.DataType): + return None + return node + + self.assertEqual(expression.transform(remove_non_list_arg).sql(), "CAST(x AS )") + + expression = parse_one("SELECT a, b FROM x") + + def remove_all_columns(node): + if isinstance(node, exp.Column): + return None + return node + + self.assertEqual(expression.transform(remove_all_columns).sql(), "SELECT FROM x") + def test_replace(self): expression = parse_one("SELECT a, b FROM x") expression.find(exp.Column).replace(parse_one("c")) diff --git a/tests/test_parser.py b/tests/test_parser.py index 1054103..9e430e2 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -114,6 +114,9 @@ class TestParser(unittest.TestCase): with self.assertRaises(ParseError): parse_one("SELECT FROM x ORDER BY") + def test_parameter(self): + self.assertEqual(parse_one("SELECT @x, @@x, @1").sql(), "SELECT @x, @@x, @1") + def test_annotations(self): expression = parse_one( """ -- cgit v1.2.3