diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-22 18:53:31 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-22 18:53:31 +0000 |
commit | 20d090151fbc2e75394fc456f49f0078e59752d8 (patch) | |
tree | 084494962f092ff80f5ef8fdba1b917206abbc83 /tests | |
parent | Adding upstream version 16.2.1. (diff) | |
download | sqlglot-20d090151fbc2e75394fc456f49f0078e59752d8.tar.xz sqlglot-20d090151fbc2e75394fc456f49f0078e59752d8.zip |
Adding upstream version 16.4.0.upstream/16.4.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dataframe/integration/dataframe_validator.py | 6 | ||||
-rw-r--r-- | tests/dataframe/unit/test_dataframe_writer.py | 8 | ||||
-rw-r--r-- | tests/dataframe/unit/test_session.py | 14 | ||||
-rw-r--r-- | tests/dialects/test_bigquery.py | 38 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 10 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 31 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 30 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 14 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 27 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 10 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 2 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 2 | ||||
-rw-r--r-- | tests/fixtures/optimizer/optimizer.sql | 35 | ||||
-rw-r--r-- | tests/fixtures/optimizer/simplify.sql | 6 | ||||
-rw-r--r-- | tests/test_expressions.py | 1 | ||||
-rw-r--r-- | tests/test_schema.py | 11 |
20 files changed, 190 insertions, 61 deletions
diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py index c84a342..22d4982 100644 --- a/tests/dataframe/integration/dataframe_validator.py +++ b/tests/dataframe/integration/dataframe_validator.py @@ -135,9 +135,9 @@ class DataFrameValidator(unittest.TestCase): data=district_data, schema=cls.sqlglot_district_schema ) cls.df_district.createOrReplaceTempView("district") - sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema) - sqlglot.schema.add_table("store", cls.sqlglot_store_schema) - sqlglot.schema.add_table("district", cls.sqlglot_district_schema) + sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema, dialect="spark") + sqlglot.schema.add_table("store", cls.sqlglot_store_schema, dialect="spark") + sqlglot.schema.add_table("district", cls.sqlglot_district_schema, dialect="spark") def setUp(self) -> None: warnings.filterwarnings("ignore", category=ResourceWarning) diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py index 3f45468..303d2f9 100644 --- a/tests/dataframe/unit/test_dataframe_writer.py +++ b/tests/dataframe/unit/test_dataframe_writer.py @@ -30,7 +30,7 @@ class TestDataFrameWriter(DataFrameSQLValidator): @mock.patch("sqlglot.schema", MappingSchema()) def test_insertInto_byName(self): - sqlglot.schema.add_table("table_name", {"employee_id": "INT"}) + sqlglot.schema.add_table("table_name", {"employee_id": "INT"}, dialect="spark") df = self.df_employee.write.byName.insertInto("table_name") expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) @@ -88,8 +88,8 @@ class TestDataFrameWriter(DataFrameSQLValidator): self.compare_sql(df, expected_statements) def test_quotes(self): - sqlglot.schema.add_table('"Test"', {'"ID"': "STRING"}) - df = self.spark.table('"Test"') + sqlglot.schema.add_table("`Test`", {"`ID`": "STRING"}, dialect="spark") + df = self.spark.table("`Test`") self.compare_sql( - df.select(df['"ID"']), ["SELECT `Test`.`ID` AS `ID` FROM `Test` AS `Test`"] + df.select(df["`ID`"]), ["SELECT `test`.`id` AS `id` FROM `test` AS `test`"] ) diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py index 0970a2e..4c275e9 100644 --- a/tests/dataframe/unit/test_session.py +++ b/tests/dataframe/unit/test_session.py @@ -71,7 +71,7 @@ class TestDataframeSession(DataFrameSQLValidator): @mock.patch("sqlglot.schema", MappingSchema()) def test_sql_select_only(self): query = "SELECT cola, colb FROM table" - sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") df = self.spark.sql(query) self.assertEqual( "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", @@ -80,17 +80,17 @@ class TestDataframeSession(DataFrameSQLValidator): @mock.patch("sqlglot.schema", MappingSchema()) def test_select_quoted(self): - sqlglot.schema.add_table('"TEST"', {"name": "string"}) + sqlglot.schema.add_table("`TEST`", {"name": "string"}, dialect="spark") self.assertEqual( - SparkSession().table('"TEST"').select(F.col("name")).sql(dialect="snowflake")[0], - '''SELECT "TEST"."name" AS "name" FROM "TEST" AS "TEST"''', + SparkSession().table("`TEST`").select(F.col("name")).sql(dialect="snowflake")[0], + '''SELECT "test"."name" AS "name" FROM "test" AS "test"''', ) @mock.patch("sqlglot.schema", MappingSchema()) def test_sql_with_aggs(self): query = "SELECT cola, colb FROM table" - sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb")) self.assertEqual( "WITH t38189 AS (SELECT cola, colb FROM table), t42330 AS (SELECT cola, colb FROM t38189) SELECT cola, SUM(colb) FROM t42330 GROUP BY cola", @@ -100,7 +100,7 @@ class TestDataframeSession(DataFrameSQLValidator): @mock.patch("sqlglot.schema", MappingSchema()) def test_sql_create(self): query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1" - sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") df = self.spark.sql(query) expected = "CREATE TABLE new_table AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" self.compare_sql(df, expected) @@ -108,7 +108,7 @@ class TestDataframeSession(DataFrameSQLValidator): @mock.patch("sqlglot.schema", MappingSchema()) def test_sql_insert(self): query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1" - sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") df = self.spark.sql(query) expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" self.compare_sql(df, expected) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 1c8aa51..e05fca0 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -23,6 +23,14 @@ class TestBigQuery(Validator): self.validate_identity("SELECT b'abc'") self.validate_identity("""SELECT * FROM UNNEST(ARRAY<STRUCT<x INT64>>[1, 2])""") self.validate_identity("SELECT AS STRUCT 1 AS a, 2 AS b") + self.validate_all( + "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z", + write={ + "": "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z", + "bigquery": "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z", + "duckdb": "SELECT {'y': ARRAY(SELECT {'b': b} FROM x)} FROM z", + }, + ) self.validate_identity("SELECT DISTINCT AS STRUCT 1 AS a, 2 AS b") self.validate_identity("SELECT AS VALUE STRUCT(1 AS a, 2 AS b)") self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])") @@ -117,6 +125,21 @@ class TestBigQuery(Validator): transpile("'\\'", read="bigquery") self.validate_all( + "r'x\\''", + write={ + "bigquery": "r'x\\''", + "hive": "'x\\''", + }, + ) + + self.validate_all( + "r'x\\y'", + write={ + "bigquery": "r'x\\y'", + "hive": "'x\\\\y'", + }, + ) + self.validate_all( "'\\\\'", write={ "bigquery": r"'\\'", @@ -458,6 +481,21 @@ class TestBigQuery(Validator): "SELECT * FROM UNNEST([1]) WITH OFFSET y", write={"bigquery": "SELECT * FROM UNNEST([1]) WITH OFFSET AS y"}, ) + self.validate_all( + "GENERATE_ARRAY(1, 4)", + read={"bigquery": "GENERATE_ARRAY(1, 4)"}, + write={"duckdb": "GENERATE_SERIES(1, 4)"}, + ) + self.validate_all( + "TO_JSON_STRING(x)", + read={"bigquery": "TO_JSON_STRING(x)"}, + write={ + "bigquery": "TO_JSON_STRING(x)", + "duckdb": "CAST(TO_JSON(x) AS TEXT)", + "presto": "JSON_FORMAT(x)", + "spark": "TO_JSON(x)", + }, + ) def test_user_defined_functions(self): self.validate_identity( diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 7584c67..b0df4df 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -53,7 +53,7 @@ class TestClickhouse(Validator): ) self.validate_all( - "CONCAT(CASE WHEN COALESCE(a, '') IS NULL THEN COALESCE(a, '') ELSE CAST(COALESCE(a, '') AS TEXT) END, CASE WHEN COALESCE(b, '') IS NULL THEN COALESCE(b, '') ELSE CAST(COALESCE(b, '') AS TEXT) END)", + "CONCAT(CASE WHEN COALESCE(CAST(a AS TEXT), '') IS NULL THEN COALESCE(CAST(a AS TEXT), '') ELSE CAST(COALESCE(CAST(a AS TEXT), '') AS TEXT) END, CASE WHEN COALESCE(CAST(b AS TEXT), '') IS NULL THEN COALESCE(CAST(b AS TEXT), '') ELSE CAST(COALESCE(CAST(b AS TEXT), '') AS TEXT) END)", read={"postgres": "CONCAT(a, b)"}, ) self.validate_all( diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 8ffdf07..3ac05cf 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -485,7 +485,7 @@ class TestDialect(Validator): "duckdb": "CAST(x AS DATE)", "hive": "TO_DATE(x)", "postgres": "CAST(x AS DATE)", - "presto": "CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE)", + "presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)", "snowflake": "CAST(x AS DATE)", }, ) @@ -749,14 +749,14 @@ class TestDialect(Validator): "drill": "DATE_ADD(CAST('2021-02-01' AS DATE), INTERVAL 1 DAY)", "duckdb": "CAST('2021-02-01' AS DATE) + INTERVAL 1 DAY", "hive": "DATE_ADD('2021-02-01', 1)", - "presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2021-02-01', 1, 10), '%Y-%m-%d'))", + "presto": "DATE_ADD('DAY', 1, CAST(CAST('2021-02-01' AS TIMESTAMP) AS DATE))", "spark": "DATE_ADD('2021-02-01', 1)", }, ) self.validate_all( "TS_OR_DS_ADD(x, 1, 'DAY')", write={ - "presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR(CAST(x AS VARCHAR), 1, 10), '%Y-%m-%d'))", + "presto": "DATE_ADD('DAY', 1, CAST(CAST(x AS TIMESTAMP) AS DATE))", "hive": "DATE_ADD(x, 1)", }, ) @@ -1192,7 +1192,7 @@ class TestDialect(Validator): }, ) self.validate_all( - "COALESCE(a, '')", + "COALESCE(CAST(a AS TEXT), '')", read={ "drill": "CONCAT(a)", "duckdb": "CONCAT(a)", @@ -1300,7 +1300,9 @@ class TestDialect(Validator): self.validate_all( "SELECT x FROM y LIMIT 10", read={ + "teradata": "SELECT TOP 10 x FROM y", "tsql": "SELECT TOP 10 x FROM y", + "snowflake": "SELECT TOP 10 x FROM y", }, write={ "sqlite": "SELECT x FROM y LIMIT 10", diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index f0caafc..4065f81 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -444,6 +444,7 @@ class TestDuckDB(Validator): def test_array(self): self.validate_identity("ARRAY(SELECT id FROM t)") + self.validate_identity("ARRAY((SELECT id FROM t))") def test_cast(self): self.validate_identity("CAST(x AS REAL)") diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index f6cc224..c9bcf16 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -4,17 +4,6 @@ from tests.dialects.test_dialect import Validator class TestHive(Validator): dialect = "hive" - def test_hive(self): - self.validate_identity("SELECT * FROM test DISTRIBUTE BY y SORT BY x DESC ORDER BY l") - self.validate_identity( - "SELECT * FROM test WHERE RAND() <= 0.1 DISTRIBUTE BY RAND() SORT BY RAND()" - ) - self.validate_identity("(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z") - self.validate_identity("(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z SORT BY x") - self.validate_identity("(SELECT 1 UNION SELECT 2) CLUSTER BY y DESC") - self.validate_identity("SELECT * FROM test CLUSTER BY y") - self.validate_identity("(SELECT 1 UNION SELECT 2) SORT BY z") - def test_bits(self): self.validate_all( "x & 1", @@ -288,7 +277,7 @@ class TestHive(Validator): "DATEDIFF(a, b)", write={ "duckdb": "DATE_DIFF('day', CAST(b AS DATE), CAST(a AS DATE))", - "presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(b AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(a AS VARCHAR), 1, 10) AS DATE))", + "presto": "DATE_DIFF('day', CAST(CAST(b AS TIMESTAMP) AS DATE), CAST(CAST(a AS TIMESTAMP) AS DATE))", "hive": "DATEDIFF(TO_DATE(a), TO_DATE(b))", "spark": "DATEDIFF(TO_DATE(a), TO_DATE(b))", "": "DATEDIFF(TS_OR_DS_TO_DATE(a), TS_OR_DS_TO_DATE(b))", @@ -316,7 +305,7 @@ class TestHive(Validator): "DATE_ADD('2020-01-01', 1)", write={ "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY", - "presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2020-01-01', 1, 10), '%Y-%m-%d'))", + "presto": "DATE_ADD('DAY', 1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))", "hive": "DATE_ADD('2020-01-01', 1)", "spark": "DATE_ADD('2020-01-01', 1)", "": "TS_OR_DS_ADD('2020-01-01', 1, 'DAY')", @@ -326,7 +315,7 @@ class TestHive(Validator): "DATE_SUB('2020-01-01', 1)", write={ "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL (1 * -1) DAY", - "presto": "DATE_ADD('DAY', 1 * -1, DATE_PARSE(SUBSTR('2020-01-01', 1, 10), '%Y-%m-%d'))", + "presto": "DATE_ADD('DAY', 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))", "hive": "DATE_ADD('2020-01-01', 1 * -1)", "spark": "DATE_ADD('2020-01-01', 1 * -1)", "": "TS_OR_DS_ADD('2020-01-01', 1 * -1, 'DAY')", @@ -341,7 +330,7 @@ class TestHive(Validator): "DATEDIFF(TO_DATE(y), x)", write={ "duckdb": "DATE_DIFF('day', CAST(x AS DATE), CAST(CAST(y AS DATE) AS DATE))", - "presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(CAST(SUBSTR(CAST(y AS VARCHAR), 1, 10) AS DATE) AS VARCHAR), 1, 10) AS DATE))", + "presto": "DATE_DIFF('day', CAST(CAST(x AS TIMESTAMP) AS DATE), CAST(CAST(CAST(CAST(y AS TIMESTAMP) AS DATE) AS TIMESTAMP) AS DATE))", "hive": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))", "spark": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))", "": "DATEDIFF(TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE(y)), TS_OR_DS_TO_DATE(x))", @@ -363,7 +352,7 @@ class TestHive(Validator): f"{unit}(x)", write={ "duckdb": f"{unit}(CAST(x AS DATE))", - "presto": f"{unit}(CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE))", + "presto": f"{unit}(CAST(CAST(x AS TIMESTAMP) AS DATE))", "hive": f"{unit}(TO_DATE(x))", "spark": f"{unit}(TO_DATE(x))", }, @@ -381,6 +370,16 @@ class TestHive(Validator): ) def test_hive(self): + self.validate_identity("SELECT * FROM test DISTRIBUTE BY y SORT BY x DESC ORDER BY l") + self.validate_identity( + "SELECT * FROM test WHERE RAND() <= 0.1 DISTRIBUTE BY RAND() SORT BY RAND()" + ) + self.validate_identity("(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z") + self.validate_identity("(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z SORT BY x") + self.validate_identity("(SELECT 1 UNION SELECT 2) CLUSTER BY y DESC") + self.validate_identity("SELECT * FROM test CLUSTER BY y") + + self.validate_identity("(SELECT 1 UNION SELECT 2) SORT BY z") self.validate_identity( "INSERT OVERWRITE TABLE zipcodes PARTITION(state = '0') VALUES (896, 'US', 'TAMPA', 33607)" ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 0b9c8b7..b8f7af0 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -6,6 +6,8 @@ class TestMySQL(Validator): dialect = "mysql" def test_ddl(self): + self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10") + self.validate_identity("DELETE FROM t WHERE a <= 10 LIMIT 10") self.validate_identity( "INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE SET x.id = 1" ) @@ -61,6 +63,22 @@ class TestMySQL(Validator): "SELECT * FROM t1, t2, t3 FOR SHARE OF t1 NOWAIT FOR UPDATE OF t2, t3 SKIP LOCKED" ) + # Index hints + self.validate_identity( + "SELECT * FROM table1 USE INDEX (col1_index, col2_index) WHERE col1 = 1 AND col2 = 2 AND col3 = 3" + ) + self.validate_identity( + "SELECT * FROM table1 IGNORE INDEX (col3_index) WHERE col1 = 1 AND col2 = 2 AND col3 = 3" + ) + self.validate_identity( + "SELECT * FROM t1 USE INDEX (i1) IGNORE INDEX FOR ORDER BY (i2) ORDER BY a" + ) + self.validate_identity("SELECT * FROM t1 USE INDEX (i1) USE INDEX (i1, i1)") + self.validate_identity("SELECT * FROM t1 USE INDEX FOR JOIN (i1) FORCE INDEX FOR JOIN (i2)") + self.validate_identity( + "SELECT * FROM t1 USE INDEX () IGNORE INDEX (i2) USE INDEX (i1) USE INDEX (i2)" + ) + # SET Commands self.validate_identity("SET @var_name = expr") self.validate_identity("SET @name = 43") @@ -80,12 +98,6 @@ class TestMySQL(Validator): self.validate_identity("SET @@SESSION.max_join_size = DEFAULT") self.validate_identity("SET @@SESSION.max_join_size = @@GLOBAL.max_join_size") self.validate_identity("SET @x = 1, SESSION sql_mode = ''") - self.validate_identity( - "SET GLOBAL sort_buffer_size = 1000000, SESSION sort_buffer_size = 1000000" - ) - self.validate_identity( - "SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000" - ) self.validate_identity("SET GLOBAL max_connections = 1000, sort_buffer_size = 1000000") self.validate_identity("SET @@GLOBAL.sort_buffer_size = 50000, sort_buffer_size = 1000000") self.validate_identity("SET CHARACTER SET 'utf8'") @@ -101,6 +113,12 @@ class TestMySQL(Validator): self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE") self.validate_identity("SELECT SCHEMA()") self.validate_identity("SELECT DATABASE()") + self.validate_identity( + "SET GLOBAL sort_buffer_size = 1000000, SESSION sort_buffer_size = 1000000" + ) + self.validate_identity( + "SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000" + ) def test_types(self): self.validate_all( diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 4e57b36..c391052 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -578,12 +578,18 @@ class TestPostgres(Validator): def test_string_concat(self): self.validate_all( + "SELECT CONCAT('abcde', 2, NULL, 22)", + write={ + "postgres": "SELECT CONCAT(COALESCE(CAST('abcde' AS TEXT), ''), COALESCE(CAST(2 AS TEXT), ''), COALESCE(CAST(NULL AS TEXT), ''), COALESCE(CAST(22 AS TEXT), ''))", + }, + ) + self.validate_all( "CONCAT(a, b)", write={ - "": "CONCAT(COALESCE(a, ''), COALESCE(b, ''))", - "duckdb": "CONCAT(COALESCE(a, ''), COALESCE(b, ''))", - "postgres": "CONCAT(COALESCE(a, ''), COALESCE(b, ''))", - "presto": "CONCAT(CAST(COALESCE(a, '') AS VARCHAR), CAST(COALESCE(b, '') AS VARCHAR))", + "": "CONCAT(COALESCE(CAST(a AS TEXT), ''), COALESCE(CAST(b AS TEXT), ''))", + "duckdb": "CONCAT(COALESCE(CAST(a AS TEXT), ''), COALESCE(CAST(b AS TEXT), ''))", + "postgres": "CONCAT(COALESCE(CAST(a AS TEXT), ''), COALESCE(CAST(b AS TEXT), ''))", + "presto": "CONCAT(CAST(COALESCE(CAST(a AS VARCHAR), '') AS VARCHAR), CAST(COALESCE(CAST(b AS VARCHAR), '') AS VARCHAR))", }, ) self.validate_all( diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 4f37be5..852b494 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -7,11 +7,11 @@ class TestPresto(Validator): def test_cast(self): self.validate_all( - "SELECT DATE_DIFF('week', CAST(SUBSTR(CAST('2009-01-01' AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST('2009-12-31' AS VARCHAR), 1, 10) AS DATE))", + "SELECT DATE_DIFF('week', CAST(CAST('2009-01-01' AS TIMESTAMP) AS DATE), CAST(CAST('2009-12-31' AS TIMESTAMP) AS DATE))", read={"redshift": "SELECT DATEDIFF(week, '2009-01-01', '2009-12-31')"}, ) self.validate_all( - "SELECT DATE_ADD('month', 18, CAST(SUBSTR(CAST('2008-02-28' AS VARCHAR), 1, 10) AS DATE))", + "SELECT DATE_ADD('month', 18, CAST(CAST('2008-02-28' AS TIMESTAMP) AS DATE))", read={"redshift": "SELECT DATEADD(month, 18, '2008-02-28')"}, ) self.validate_all( @@ -664,16 +664,31 @@ class TestPresto(Validator): "spark": "TO_JSON(x)", }, write={ + "bigquery": "TO_JSON_STRING(x)", + "duckdb": "CAST(TO_JSON(x) AS TEXT)", "presto": "JSON_FORMAT(x)", "spark": "TO_JSON(x)", }, ) - self.validate_all( - "JSON_FORMAT(JSON 'x')", + """JSON_FORMAT(JSON '"x"')""", + write={ + "bigquery": """TO_JSON_STRING(CAST('"x"' AS JSON))""", + "duckdb": """CAST(TO_JSON(CAST('"x"' AS JSON)) AS TEXT)""", + "presto": """JSON_FORMAT(CAST('"x"' AS JSON))""", + "spark": """REGEXP_EXTRACT(TO_JSON(FROM_JSON('["x"]', SCHEMA_OF_JSON('["x"]'))), '^.(.*).$', 1)""", + }, + ) + self.validate_all( + """SELECT JSON_FORMAT(JSON '{"a": 1, "b": "c"}')""", + write={ + "spark": """SELECT REGEXP_EXTRACT(TO_JSON(FROM_JSON('[{"a": 1, "b": "c"}]', SCHEMA_OF_JSON('[{"a": 1, "b": "c"}]'))), '^.(.*).$', 1)""", + }, + ) + self.validate_all( + """SELECT JSON_FORMAT(JSON '[1, 2, 3]')""", write={ - "presto": "JSON_FORMAT(CAST('x' AS JSON))", - "spark": "TO_JSON('x')", + "spark": "SELECT REGEXP_EXTRACT(TO_JSON(FROM_JSON('[[1, 2, 3]]', SCHEMA_OF_JSON('[[1, 2, 3]]'))), '^.(.*).$', 1)", }, ) diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index f4efe24..88168ab 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -180,7 +180,7 @@ class TestRedshift(Validator): "DATEDIFF('day', a, b)", write={ "redshift": "DATEDIFF(day, a, b)", - "presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(a AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(b AS VARCHAR), 1, 10) AS DATE))", + "presto": "DATE_DIFF('day', CAST(CAST(a AS TIMESTAMP) AS DATE), CAST(CAST(b AS TIMESTAMP) AS DATE))", }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 426e188..0514149 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -575,6 +575,7 @@ class TestSnowflake(Validator): ) def test_ddl(self): + self.validate_identity("CREATE OR REPLACE VIEW foo (uid) COPY GRANTS AS (SELECT 1)") self.validate_identity("CREATE TABLE geospatial_table (id INT, g GEOGRAPHY)") self.validate_identity("CREATE MATERIALIZED VIEW a COMMENT='...' AS SELECT 1 FROM x") self.validate_identity("CREATE DATABASE mytestdb_clone CLONE mytestdb") diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 7c8ca1b..54c39e7 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -130,7 +130,7 @@ TBLPROPERTIES ( write={ "duckdb": "CAST(x AS DATE)", "hive": "TO_DATE(x)", - "presto": "CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE)", + "presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)", "spark": "TO_DATE(x)", }, ) @@ -268,10 +268,10 @@ TBLPROPERTIES ( write={ "databricks": "SELECT DATEDIFF(MONTH, TO_DATE('2020-01-01'), TO_DATE('2020-03-05'))", "hive": "SELECT MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01'))", - "presto": "SELECT DATE_DIFF('MONTH', CAST(SUBSTR(CAST('2020-01-01' AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST('2020-03-05' AS VARCHAR), 1, 10) AS DATE))", + "presto": "SELECT DATE_DIFF('MONTH', CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE), CAST(CAST('2020-03-05' AS TIMESTAMP) AS DATE))", "spark": "SELECT DATEDIFF(MONTH, TO_DATE('2020-01-01'), TO_DATE('2020-03-05'))", "spark2": "SELECT MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01'))", - "trino": "SELECT DATE_DIFF('MONTH', CAST(SUBSTR(CAST('2020-01-01' AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST('2020-03-05' AS VARCHAR), 1, 10) AS DATE))", + "trino": "SELECT DATE_DIFF('MONTH', CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE), CAST(CAST('2020-03-05' AS TIMESTAMP) AS DATE))", }, ) @@ -359,7 +359,7 @@ TBLPROPERTIES ( "MONTH('2021-03-01')", write={ "duckdb": "MONTH(CAST('2021-03-01' AS DATE))", - "presto": "MONTH(CAST(SUBSTR(CAST('2021-03-01' AS VARCHAR), 1, 10) AS DATE))", + "presto": "MONTH(CAST(CAST('2021-03-01' AS TIMESTAMP) AS DATE))", "hive": "MONTH(TO_DATE('2021-03-01'))", "spark": "MONTH(TO_DATE('2021-03-01'))", }, @@ -368,7 +368,7 @@ TBLPROPERTIES ( "YEAR('2021-03-01')", write={ "duckdb": "YEAR(CAST('2021-03-01' AS DATE))", - "presto": "YEAR(CAST(SUBSTR(CAST('2021-03-01' AS VARCHAR), 1, 10) AS DATE))", + "presto": "YEAR(CAST(CAST('2021-03-01' AS TIMESTAMP) AS DATE))", "hive": "YEAR(TO_DATE('2021-03-01'))", "spark": "YEAR(TO_DATE('2021-03-01'))", }, diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 8789aed..953d64d 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -6,6 +6,8 @@ class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): + self.validate_identity("SELECT * FROM t WITH (TABLOCK, INDEX(myindex))") + self.validate_identity("SELECT * FROM t WITH (NOWAIT)") self.validate_identity("SELECT CASE WHEN a > 1 THEN b END") self.validate_identity("SELECT * FROM taxi ORDER BY 1 OFFSET 0 ROWS FETCH NEXT 3 ROWS ONLY") self.validate_identity("END") diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index e0ea9cb..ff3162b 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -519,8 +519,6 @@ SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b) SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b) SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score) SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score) -SELECT * FROM t WITH (TABLOCK, INDEX(myindex)) -SELECT * FROM t WITH (NOWAIT) CREATE TABLE foo AS (SELECT 1) UNION ALL (SELECT 2) CREATE TABLE foo (id INT PRIMARY KEY ASC) CREATE TABLE a.b AS SELECT 1 diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index 0cb1a58..214535a 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -646,3 +646,38 @@ CROSS JOIN LATERAL ( "l"."log_date" DESC NULLS LAST LIMIT 1 ) AS "l"; + +# title: bigquery column identifiers are case-insensitive +# execute: false +# dialect: bigquery +WITH cte AS ( + SELECT + refresh_date AS `reFREsh_date`, + term AS `TeRm`, + `rank` + FROM `bigquery-public-data.GooGle_tReNDs.TOp_TeRmS` +) +SELECT + refresh_date AS `Day`, + term AS Top_Term, + rank, +FROM cte +WHERE + rank = 1 + AND refresh_date >= DATE_SUB(CURRENT_DATE(), INTERVAL 2 WEEK) +GROUP BY `dAy`, `top_term`, rank +ORDER BY `DaY` DESC; +SELECT + `TOp_TeRmS`.`refresh_date` AS `day`, + `TOp_TeRmS`.`term` AS `top_term`, + `TOp_TeRmS`.`rank` AS `rank` +FROM `bigquery-public-data`.`GooGle_tReNDs`.`TOp_TeRmS` AS `TOp_TeRmS` +WHERE + `TOp_TeRmS`.`rank` = 1 + AND CAST(`TOp_TeRmS`.`refresh_date` AS DATE) >= DATE_SUB(CURRENT_DATE, INTERVAL 2 WEEK) +GROUP BY + `TOp_TeRmS`.`refresh_date`, + `TOp_TeRmS`.`term`, + `TOp_TeRmS`.`rank` +ORDER BY + `day` DESC; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 5c8d371..e0aded4 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -240,6 +240,12 @@ A AND B AND C; SELECT x WHERE TRUE; SELECT x; +SELECT x FROM y LEFT JOIN z ON TRUE; +SELECT x FROM y CROSS JOIN z; + +SELECT x FROM y JOIN z USING (x); +SELECT x FROM y JOIN z USING (x); + -------------------------------------- -- Parenthesis removal -------------------------------------- diff --git a/tests/test_expressions.py b/tests/test_expressions.py index c9b5279..e7a37f3 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -813,6 +813,7 @@ FROM foo""", self.assertEqual( exp.DataType.build("struct<x int>", dialect="spark").sql(), "STRUCT<x INT>" ) + self.assertEqual(exp.DataType.build("USER-DEFINED").sql(), "USER-DEFINED") def test_rename_table(self): self.assertEqual( diff --git a/tests/test_schema.py b/tests/test_schema.py index b03e7e7..23690b9 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -201,7 +201,7 @@ class TestSchema(unittest.TestCase): def test_schema_normalization(self): schema = MappingSchema( schema={"x": {"`y`": {"Z": {"a": "INT", "`B`": "VARCHAR"}, "w": {"C": "INT"}}}}, - dialect="spark", + dialect="clickhouse", ) table_z = exp.Table(this="z", db="y", catalog="x") @@ -228,4 +228,11 @@ class TestSchema(unittest.TestCase): # Check that the correct dialect is used when calling schema methods schema = MappingSchema(schema={"[Fo]": {"x": "int"}}, dialect="tsql") - self.assertEqual(schema.column_names("[Fo]"), schema.column_names("`Fo`", dialect="spark")) + self.assertEqual( + schema.column_names("[Fo]"), schema.column_names("`Fo`", dialect="clickhouse") + ) + + # Check that all column identifiers are normalized to lowercase for BigQuery, even quoted + # ones. Also, ensure that tables aren't normalized, since they're case-sensitive by default. + schema = MappingSchema(schema={"Foo": {"`BaR`": "int"}}, dialect="bigquery") + self.assertEqual(schema.column_names("Foo"), ["bar"]) |