From 28cc22419e32a65fea2d1678400265b8cabc3aff Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Thu, 15 Sep 2022 18:46:17 +0200
Subject: Adding upstream version 6.0.4.

Signed-off-by: Daniel Baumann
---
 tests/__init__.py | 0
 tests/dialects/__init__.py | 0
 tests/dialects/test_bigquery.py | 238 +++
 tests/dialects/test_clickhouse.py | 25 +
 tests/dialects/test_dialect.py | 981 +++++++++++
 tests/dialects/test_duckdb.py | 249 +++
 tests/dialects/test_hive.py | 541 ++++++
 tests/dialects/test_mysql.py | 79 +
 tests/dialects/test_postgres.py | 93 +
 tests/dialects/test_presto.py | 422 +++++
 tests/dialects/test_snowflake.py | 145 ++
 tests/dialects/test_spark.py | 226 +++
 tests/dialects/test_sqlite.py | 72 +
 tests/dialects/test_starrocks.py | 8 +
 tests/dialects/test_tableau.py | 62 +
 tests/fixtures/identity.sql | 514 ++++++
 tests/fixtures/optimizer/eliminate_subqueries.sql | 42 +
 .../optimizer/expand_multi_table_selects.sql | 11 +
 tests/fixtures/optimizer/isolate_table_selects.sql | 20 +
 tests/fixtures/optimizer/normalize.sql | 41 +
 tests/fixtures/optimizer/optimize_joins.sql | 20 +
 tests/fixtures/optimizer/optimizer.sql | 148 ++
 tests/fixtures/optimizer/pushdown_predicates.sql | 32 +
 tests/fixtures/optimizer/pushdown_projections.sql | 41 +
 tests/fixtures/optimizer/qualify_columns.sql | 233 +++
 .../optimizer/qualify_columns__invalid.sql | 14 +
 tests/fixtures/optimizer/qualify_tables.sql | 17 +
 tests/fixtures/optimizer/quote_identities.sql | 8 +
 tests/fixtures/optimizer/simplify.sql | 350 ++++
 tests/fixtures/optimizer/tpc-h/customer.csv.gz | Bin 0 -> 125178 bytes
 tests/fixtures/optimizer/tpc-h/lineitem.csv.gz | Bin 0 -> 304069 bytes
 tests/fixtures/optimizer/tpc-h/nation.csv.gz | Bin 0 -> 1002 bytes
 tests/fixtures/optimizer/tpc-h/orders.csv.gz | Bin 0 -> 66113 bytes
 tests/fixtures/optimizer/tpc-h/part.csv.gz | Bin 0 -> 251365 bytes
 tests/fixtures/optimizer/tpc-h/partsupp.csv.gz | Bin 0 -> 303483 bytes
 tests/fixtures/optimizer/tpc-h/region.csv.gz | Bin 0 -> 284 bytes
 tests/fixtures/optimizer/tpc-h/supplier.csv.gz | Bin 0 -> 317596 bytes
 tests/fixtures/optimizer/tpc-h/tpc-h.sql | 1810 ++++++++++++++++++++
 tests/fixtures/optimizer/unnest_subqueries.sql | 206 +++
 tests/fixtures/partial.sql | 8 +
 tests/fixtures/pretty.sql | 285 +++
 tests/helpers.py | 130 ++
 tests/test_build.py | 384 +++++
 tests/test_diff.py | 137 ++
 tests/test_docs.py | 30 +
 tests/test_executor.py | 72 +
 tests/test_expressions.py | 415 +++++
 tests/test_generator.py | 30 +
 tests/test_helper.py | 31 +
 tests/test_optimizer.py | 276 +++
 tests/test_parser.py | 195 +++
 tests/test_time.py | 14 +
 tests/test_transforms.py | 16 +
 tests/test_transpile.py | 349 ++++
 54 files changed, 9020 insertions(+)
 create mode 100644 tests/__init__.py
 create mode 100644 tests/dialects/__init__.py
 create mode 100644 tests/dialects/test_bigquery.py
 create mode 100644 tests/dialects/test_clickhouse.py
 create mode 100644 tests/dialects/test_dialect.py
 create mode 100644 tests/dialects/test_duckdb.py
 create mode 100644 tests/dialects/test_hive.py
 create mode 100644 tests/dialects/test_mysql.py
 create mode 100644 tests/dialects/test_postgres.py
 create mode 100644 tests/dialects/test_presto.py
 create mode 100644 tests/dialects/test_snowflake.py
 create mode 100644 tests/dialects/test_spark.py
 create mode 100644 tests/dialects/test_sqlite.py
 create mode 100644 tests/dialects/test_starrocks.py
 create mode 100644 tests/dialects/test_tableau.py
 create mode 100644 tests/fixtures/identity.sql
 create mode 100644 tests/fixtures/optimizer/eliminate_subqueries.sql
 create mode 100644 tests/fixtures/optimizer/expand_multi_table_selects.sql
 create mode 100644 tests/fixtures/optimizer/isolate_table_selects.sql
 create mode 100644 tests/fixtures/optimizer/normalize.sql
 create mode 100644 tests/fixtures/optimizer/optimize_joins.sql
 create mode 100644 tests/fixtures/optimizer/optimizer.sql
 create mode 100644 tests/fixtures/optimizer/pushdown_predicates.sql
 create mode 100644 tests/fixtures/optimizer/pushdown_projections.sql
 create mode 100644 tests/fixtures/optimizer/qualify_columns.sql
 create mode 100644 tests/fixtures/optimizer/qualify_columns__invalid.sql
 create mode 100644 tests/fixtures/optimizer/qualify_tables.sql
 create mode 100644 tests/fixtures/optimizer/quote_identities.sql
 create mode 100644 tests/fixtures/optimizer/simplify.sql
 create mode 100644 tests/fixtures/optimizer/tpc-h/customer.csv.gz
 create mode 100644 tests/fixtures/optimizer/tpc-h/lineitem.csv.gz
 create mode 100644 tests/fixtures/optimizer/tpc-h/nation.csv.gz
 create mode 100644 tests/fixtures/optimizer/tpc-h/orders.csv.gz
 create mode 100644 tests/fixtures/optimizer/tpc-h/part.csv.gz
 create mode 100644 tests/fixtures/optimizer/tpc-h/partsupp.csv.gz
 create mode 100644 tests/fixtures/optimizer/tpc-h/region.csv.gz
 create mode 100644 tests/fixtures/optimizer/tpc-h/supplier.csv.gz
 create mode 100644 tests/fixtures/optimizer/tpc-h/tpc-h.sql
 create mode 100644 tests/fixtures/optimizer/unnest_subqueries.sql
 create mode 100644 tests/fixtures/partial.sql
 create mode 100644 tests/fixtures/pretty.sql
 create mode 100644 tests/helpers.py
 create mode 100644 tests/test_build.py
 create mode 100644 tests/test_diff.py
 create mode 100644 tests/test_docs.py
 create mode 100644 tests/test_executor.py
 create mode 100644 tests/test_expressions.py
 create mode 100644 tests/test_generator.py
 create mode 100644 tests/test_helper.py
 create mode 100644 tests/test_optimizer.py
 create mode 100644 tests/test_parser.py
 create mode 100644 tests/test_time.py
 create mode 100644 tests/test_transforms.py
 create mode 100644 tests/test_transpile.py

diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/dialects/__init__.py b/tests/dialects/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
new file mode 100644
index 0000000..1337c3d
--- /dev/null
+++ b/tests/dialects/test_bigquery.py
@@ -0,0 +1,238 @@
+from sqlglot import ErrorLevel, ParseError, UnsupportedError, transpile
+from tests.dialects.test_dialect import Validator
+
+
+class TestBigQuery(Validator):
+    dialect = "bigquery"
+
+    def test_bigquery(self):
+        self.validate_all(
+            '"""x"""',
+            write={
+                "bigquery": "'x'",
+                "duckdb": "'x'",
+                "presto": "'x'",
+                "hive": "'x'",
+                "spark": "'x'",
+            },
+        )
+        self.validate_all(
+            '"""x\'"""',
+            write={
+                "bigquery": "'x\\''",
+                "duckdb": "'x'''",
+                "presto": "'x'''",
+                "hive": "'x\\''",
+                "spark": "'x\\''",
+            },
+        )
+        self.validate_all(
+            r'r"""/\*.*\*/"""',
+            write={
+                "bigquery": r"'/\\*.*\\*/'",
+                "duckdb": r"'/\*.*\*/'",
+                "presto": r"'/\*.*\*/'",
+                "hive": r"'/\\*.*\\*/'",
+                "spark": r"'/\\*.*\\*/'",
+            },
+        )
+        self.validate_all(
+            R'R"""/\*.*\*/"""',
+            write={
+                "bigquery": R"'/\\*.*\\*/'",
+                "duckdb": R"'/\*.*\*/'",
+                "presto": R"'/\*.*\*/'",
+                "hive": R"'/\\*.*\\*/'",
+                "spark": R"'/\\*.*\\*/'",
+            },
+        )
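The four string-literal cases above all follow one rule: BigQuery's r/R prefix keeps backslashes literal, so targets without raw literals (hive, spark) must double them. A quick way to see the rule outside the test harness (a sketch, assuming the sqlglot 6.x API used in this patch):

    from sqlglot import transpile

    # BigQuery raw string -> Spark: backslashes get escaped, per the cases above
    print(transpile(r'r"""/\*.*\*/"""', read="bigquery", write="spark")[0])
    # expected, per the tests: '/\\*.*\\*/'
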
+        self.validate_all(
+            "CAST(a AS INT64)",
+            write={
+                "bigquery": "CAST(a AS INT64)",
+                "duckdb": "CAST(a AS BIGINT)",
+                "presto": "CAST(a AS BIGINT)",
+                "hive": "CAST(a AS BIGINT)",
+                "spark": "CAST(a AS LONG)",
+            },
+        )
+        self.validate_all(
+            "CAST(a AS NUMERIC)",
+            write={
+                "bigquery": "CAST(a AS NUMERIC)",
+                "duckdb": "CAST(a AS DECIMAL)",
+                "presto": "CAST(a AS DECIMAL)",
+                "hive": "CAST(a AS DECIMAL)",
+                "spark": "CAST(a AS DECIMAL)",
+            },
+        )
+        self.validate_all(
+            "[1, 2, 3]",
+            read={
+                "duckdb": "LIST_VALUE(1, 2, 3)",
+                "presto": "ARRAY[1, 2, 3]",
+                "hive": "ARRAY(1, 2, 3)",
+                "spark": "ARRAY(1, 2, 3)",
+            },
+            write={
+                "bigquery": "[1, 2, 3]",
+                "duckdb": "LIST_VALUE(1, 2, 3)",
+                "presto": "ARRAY[1, 2, 3]",
+                "hive": "ARRAY(1, 2, 3)",
+                "spark": "ARRAY(1, 2, 3)",
+            },
+        )
+        self.validate_all(
+            "SELECT * FROM UNNEST(['7', '14']) AS x",
+            read={
+                "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)",
+            },
+            write={
+                "bigquery": "SELECT * FROM UNNEST(['7', '14']) AS x",
+                "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS (x)",
+                "hive": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)",
+                "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)",
+            },
+        )
+
+        self.validate_all(
+            "x IS unknown",
+            write={
+                "bigquery": "x IS NULL",
+                "duckdb": "x IS NULL",
+                "presto": "x IS NULL",
+                "hive": "x IS NULL",
+                "spark": "x IS NULL",
+            },
+        )
+        self.validate_all(
+            "current_datetime",
+            write={
+                "bigquery": "CURRENT_DATETIME()",
+                "duckdb": "CURRENT_DATETIME()",
+                "presto": "CURRENT_DATETIME()",
+                "hive": "CURRENT_DATETIME()",
+                "spark": "CURRENT_DATETIME()",
+            },
+        )
+        self.validate_all(
+            "current_time",
+            write={
+                "bigquery": "CURRENT_TIME()",
+                "duckdb": "CURRENT_TIME()",
+                "presto": "CURRENT_TIME()",
+                "hive": "CURRENT_TIME()",
+                "spark": "CURRENT_TIME()",
+            },
+        )
+        self.validate_all(
+            "current_timestamp",
+            write={
+                "bigquery": "CURRENT_TIMESTAMP()",
+                "duckdb": "CURRENT_TIMESTAMP()",
+                "postgres": "CURRENT_TIMESTAMP",
+                "presto": "CURRENT_TIMESTAMP()",
+                "hive": "CURRENT_TIMESTAMP()",
+                "spark": "CURRENT_TIMESTAMP()",
+            },
+        )
+        self.validate_all(
+            "current_timestamp()",
+            write={
+                "bigquery": "CURRENT_TIMESTAMP()",
+                "duckdb": "CURRENT_TIMESTAMP()",
+                "postgres": "CURRENT_TIMESTAMP",
+                "presto": "CURRENT_TIMESTAMP()",
+                "hive": "CURRENT_TIMESTAMP()",
+                "spark": "CURRENT_TIMESTAMP()",
+            },
+        )
+
+        self.validate_identity(
+            "SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)"
+        )
+
+        self.validate_identity(
+            "SELECT LAST_VALUE(a IGNORE NULLS) OVER y FROM x WINDOW y AS (PARTITION BY CATEGORY)",
+        )
+
+        self.validate_all(
+            "CREATE TABLE db.example_table (col_a struct<struct_col_a int, struct_col_b string>)",
+            write={
+                "bigquery": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT64, struct_col_b STRING>)",
+                "duckdb": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT, struct_col_b TEXT>)",
+                "presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b VARCHAR))",
+                "hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT, struct_col_b STRING>)",
+                "spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRING>)",
+            },
+        )
+        self.validate_all(
+            "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT64, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)",
+            write={
+                "bigquery": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT64, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)",
+                "presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a BIGINT, struct_col_b ROW(nested_col_a VARCHAR, nested_col_b VARCHAR)))",
+                "hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a BIGINT, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)",
+                "spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: LONG, struct_col_b: STRUCT<nested_col_a: STRING, nested_col_b: STRING>>)",
+            },
+        )
+        self.validate_all(
+            "SELECT * FROM a WHERE b IN UNNEST([1, 2, 3])",
+            write={
+                "bigquery": "SELECT * FROM a WHERE b IN UNNEST([1, 2, 3])",
+                "mysql": "SELECT * FROM a WHERE b IN (SELECT UNNEST(ARRAY(1, 2, 3)))",
+                "presto": "SELECT * FROM a WHERE b IN (SELECT UNNEST(ARRAY[1, 2, 3]))",
+                "hive": "SELECT * FROM a WHERE b IN (SELECT UNNEST(ARRAY(1, 2, 3)))",
+                "spark": "SELECT * FROM a WHERE b IN (SELECT UNNEST(ARRAY(1, 2, 3)))",
+            },
+        )
+
+        # Reference: https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#set_operators
+        with self.assertRaises(UnsupportedError):
+            transpile(
+                "SELECT * FROM a INTERSECT ALL SELECT * FROM b",
+                write="bigquery",
+                unsupported_level=ErrorLevel.RAISE,
+            )
+
+        with self.assertRaises(UnsupportedError):
+            transpile(
+                "SELECT * FROM a EXCEPT ALL SELECT * FROM b",
+                write="bigquery",
+                unsupported_level=ErrorLevel.RAISE,
+            )
+
+        with self.assertRaises(ParseError):
+            transpile("SELECT * FROM UNNEST(x) AS x(y)", read="bigquery")
+
+        self.validate_all(
+            "DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY)",
+            write={
+                "postgres": "CURRENT_DATE - INTERVAL '1' DAY",
+            },
+        )
+        self.validate_all(
+            "DATE_ADD(CURRENT_DATE(), INTERVAL 1 DAY)",
+            write={
+                "bigquery": "DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)",
+                "duckdb": "CURRENT_DATE + INTERVAL 1 DAY",
+                "mysql": "DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)",
+                "postgres": "CURRENT_DATE + INTERVAL '1' DAY",
+                "presto": "DATE_ADD(DAY, 1, CURRENT_DATE)",
+                "hive": "DATE_ADD(CURRENT_DATE, 1)",
+                "spark": "DATE_ADD(CURRENT_DATE, 1)",
+            },
+        )
+        self.validate_all(
+            "CURRENT_DATE('UTC')",
+            write={
+                "mysql": "CURRENT_DATE AT TIME ZONE 'UTC'",
+                "postgres": "CURRENT_DATE AT TIME ZONE 'UTC'",
+            },
+        )
+        self.validate_all(
+            "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10",
+            write={
+                "bigquery": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10",
+                "snowflake": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a NULLS FIRST LIMIT 10",
+            },
+        )
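The two assertRaises blocks above pin down the error-level contract: BigQuery has no INTERSECT ALL or EXCEPT ALL, and only ErrorLevel.RAISE turns that into an exception. A minimal sketch of the same contract (assuming the sqlglot 6.x API used throughout this patch):

    from sqlglot import ErrorLevel, UnsupportedError, transpile

    sql = "SELECT * FROM a INTERSECT ALL SELECT * FROM b"
    # IGNORE emits best-effort SQL instead of raising
    best_effort = transpile(sql, write="bigquery", unsupported_level=ErrorLevel.IGNORE)[0]
    # RAISE surfaces the unsupported construct as an exception
    try:
        transpile(sql, write="bigquery", unsupported_level=ErrorLevel.RAISE)
    except UnsupportedError as err:
        print("raised as expected:", err)
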
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
new file mode 100644
index 0000000..e5b1516
--- /dev/null
+++ b/tests/dialects/test_clickhouse.py
@@ -0,0 +1,25 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestClickhouse(Validator):
+    dialect = "clickhouse"
+
+    def test_clickhouse(self):
+        self.validate_identity("dictGet(x, 'y')")
+        self.validate_identity("SELECT * FROM x FINAL")
+        self.validate_identity("SELECT * FROM x AS y FINAL")
+
+        self.validate_all(
+            "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
+            write={
+                "clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
+                "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
+            },
+        )
+
+        self.validate_all(
+            "CAST(1 AS NULLABLE(Int64))",
+            write={
+                "clickhouse": "CAST(1 AS Nullable(BIGINT))",
+            },
+        )
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
new file mode 100644
index 0000000..3993565
--- /dev/null
+++ b/tests/dialects/test_dialect.py
@@ -0,0 +1,981 @@
+import unittest
+
+from sqlglot import (
+    Dialect,
+    Dialects,
+    ErrorLevel,
+    UnsupportedError,
+    parse_one,
+    transpile,
+)
+
+
+class Validator(unittest.TestCase):
+    dialect = None
+
+    def validate(self, sql, target, **kwargs):
+        self.assertEqual(transpile(sql, **kwargs)[0], target)
+
+    def validate_identity(self, sql):
+        self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql)
+
+    def validate_all(self, sql, read=None, write=None, pretty=False):
+        """
+        Validate that:
+        1. Everything in `read` transpiles to `sql`
+        2. `sql` transpiles to everything in `write`
+
+        Args:
+            sql (str): Main SQL expression, in this validator's dialect
+            read (dict): Mapping of dialect -> SQL
+            write (dict): Mapping of dialect -> SQL
+            pretty (bool): prettify the generated `write` SQL before comparing
+        """
+        expression = parse_one(sql, read=self.dialect)
+
+        for read_dialect, read_sql in (read or {}).items():
+            with self.subTest(f"{read_dialect} -> {sql}"):
+                self.assertEqual(
+                    parse_one(read_sql, read_dialect).sql(
+                        self.dialect, unsupported_level=ErrorLevel.IGNORE
+                    ),
+                    sql,
+                )
+
+        for write_dialect, write_sql in (write or {}).items():
+            with self.subTest(f"{sql} -> {write_dialect}"):
+                if write_sql is UnsupportedError:
+                    with self.assertRaises(UnsupportedError):
+                        expression.sql(
+                            write_dialect, unsupported_level=ErrorLevel.RAISE
+                        )
+                else:
+                    self.assertEqual(
+                        expression.sql(
+                            write_dialect,
+                            unsupported_level=ErrorLevel.IGNORE,
+                            pretty=pretty,
+                        ),
+                        write_sql,
+                    )
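In short, validate_all treats `sql` as the hub: each `read` entry must transpile into `sql`, and `sql` must transpile into each `write` entry. A hand-rolled equivalent of one such row (a sketch mirroring the helper above, not part of the patch):

    from sqlglot import parse_one

    # write direction: the hub expression rendered for a target dialect
    assert parse_one("CAST(a AS TEXT)").sql("hive") == "CAST(a AS STRING)"
    # read direction: a dialect-specific input rendered back as the hub SQL
    assert parse_one("CAST(a AS STRING)", read="hive").sql() == "CAST(a AS TEXT)"
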
+
+
+class TestDialect(Validator):
+    maxDiff = None
+
+    def test_enum(self):
+        for dialect in Dialects:
+            self.assertIsNotNone(Dialect[dialect])
+            self.assertIsNotNone(Dialect.get(dialect))
+            self.assertIsNotNone(Dialect.get_or_raise(dialect))
+            self.assertIsNotNone(Dialect[dialect.value])
+
+    def test_cast(self):
+        self.validate_all(
+            "CAST(a AS TEXT)",
+            write={
+                "bigquery": "CAST(a AS STRING)",
+                "clickhouse": "CAST(a AS TEXT)",
+                "duckdb": "CAST(a AS TEXT)",
+                "mysql": "CAST(a AS TEXT)",
+                "hive": "CAST(a AS STRING)",
+                "oracle": "CAST(a AS CLOB)",
+                "postgres": "CAST(a AS TEXT)",
+                "presto": "CAST(a AS VARCHAR)",
+                "snowflake": "CAST(a AS TEXT)",
+                "spark": "CAST(a AS STRING)",
+                "starrocks": "CAST(a AS STRING)",
+            },
+        )
+        self.validate_all(
+            "CAST(a AS STRING)",
+            write={
+                "bigquery": "CAST(a AS STRING)",
+                "duckdb": "CAST(a AS TEXT)",
+                "mysql": "CAST(a AS TEXT)",
+                "hive": "CAST(a AS STRING)",
+                "oracle": "CAST(a AS CLOB)",
+                "postgres": "CAST(a AS TEXT)",
+                "presto": "CAST(a AS VARCHAR)",
+                "snowflake": "CAST(a AS TEXT)",
+                "spark": "CAST(a AS STRING)",
+                "starrocks": "CAST(a AS STRING)",
+            },
+        )
+        self.validate_all(
+            "CAST(a AS VARCHAR)",
+            write={
+                "bigquery": "CAST(a AS STRING)",
+                "duckdb": "CAST(a AS TEXT)",
+                "mysql": "CAST(a AS VARCHAR)",
+                "hive": "CAST(a AS STRING)",
+                "oracle": "CAST(a AS VARCHAR2)",
+                "postgres": "CAST(a AS VARCHAR)",
+                "presto": "CAST(a AS VARCHAR)",
+                "snowflake": "CAST(a AS VARCHAR)",
+                "spark": "CAST(a AS STRING)",
+                "starrocks": "CAST(a AS VARCHAR)",
+            },
+        )
+        self.validate_all(
+            "CAST(a AS VARCHAR(3))",
+            write={
+                "bigquery": "CAST(a AS STRING(3))",
+                "duckdb": "CAST(a AS TEXT(3))",
+                "mysql": "CAST(a AS VARCHAR(3))",
+                "hive": "CAST(a AS VARCHAR(3))",
+                "oracle": "CAST(a AS VARCHAR2(3))",
+                "postgres": "CAST(a AS VARCHAR(3))",
+                "presto": "CAST(a AS VARCHAR(3))",
+                "snowflake": "CAST(a AS VARCHAR(3))",
+                "spark": "CAST(a AS VARCHAR(3))",
+                "starrocks": "CAST(a AS VARCHAR(3))",
+            },
+        )
+        self.validate_all(
+            "CAST(a AS SMALLINT)",
+            write={
+                "bigquery": "CAST(a AS INT64)",
+                "duckdb": "CAST(a AS SMALLINT)",
+                "mysql": "CAST(a AS SMALLINT)",
+                "hive": "CAST(a AS SMALLINT)",
+                "oracle": "CAST(a AS NUMBER)",
+                "postgres": "CAST(a AS SMALLINT)",
+                "presto": "CAST(a AS SMALLINT)",
+                "snowflake": "CAST(a AS SMALLINT)",
+                "spark": "CAST(a AS SHORT)",
+                "sqlite": "CAST(a AS INTEGER)",
+                "starrocks": "CAST(a AS SMALLINT)",
+            },
+        )
+        self.validate_all(
+            "CAST(a AS DOUBLE)",
+            write={
+                "bigquery": "CAST(a AS FLOAT64)",
+                "clickhouse": "CAST(a AS DOUBLE)",
+                "duckdb": "CAST(a AS DOUBLE)",
+                "mysql": "CAST(a AS DOUBLE)",
+                "hive": "CAST(a AS DOUBLE)",
+                "oracle":
"CAST(a AS DOUBLE PRECISION)", + "postgres": "CAST(a AS DOUBLE PRECISION)", + "presto": "CAST(a AS DOUBLE)", + "snowflake": "CAST(a AS DOUBLE)", + "spark": "CAST(a AS DOUBLE)", + "starrocks": "CAST(a AS DOUBLE)", + }, + ) + self.validate_all( + "CAST(a AS TIMESTAMP)", write={"starrocks": "CAST(a AS DATETIME)"} + ) + self.validate_all( + "CAST(a AS TIMESTAMPTZ)", write={"starrocks": "CAST(a AS DATETIME)"} + ) + self.validate_all("CAST(a AS TINYINT)", write={"oracle": "CAST(a AS NUMBER)"}) + self.validate_all("CAST(a AS SMALLINT)", write={"oracle": "CAST(a AS NUMBER)"}) + self.validate_all("CAST(a AS BIGINT)", write={"oracle": "CAST(a AS NUMBER)"}) + self.validate_all("CAST(a AS INT)", write={"oracle": "CAST(a AS NUMBER)"}) + self.validate_all( + "CAST(a AS DECIMAL)", + read={"oracle": "CAST(a AS NUMBER)"}, + write={"oracle": "CAST(a AS NUMBER)"}, + ) + + def test_time(self): + self.validate_all( + "STR_TO_TIME(x, '%Y-%m-%dT%H:%M:%S')", + read={ + "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')", + }, + write={ + "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", + "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')", + "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)", + "presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')", + "spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')", + }, + ) + self.validate_all( + "STR_TO_TIME('2020-01-01', '%Y-%m-%d')", + write={ + "duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')", + "hive": "CAST('2020-01-01' AS TIMESTAMP)", + "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')", + "spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')", + }, + ) + self.validate_all( + "STR_TO_TIME(x, '%y')", + write={ + "duckdb": "STRPTIME(x, '%y')", + "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)", + "presto": "DATE_PARSE(x, '%y')", + "spark": "TO_TIMESTAMP(x, 'yy')", + }, + ) + self.validate_all( + "STR_TO_UNIX('2020-01-01', '%Y-%M-%d')", + write={ + "duckdb": "EPOCH(STRPTIME('2020-01-01', '%Y-%M-%d'))", + "hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-mm-dd')", + "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%i-%d'))", + }, + ) + self.validate_all( + "TIME_STR_TO_DATE('2020-01-01')", + write={ + "duckdb": "CAST('2020-01-01' AS DATE)", + "hive": "TO_DATE('2020-01-01')", + "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", + }, + ) + self.validate_all( + "TIME_STR_TO_TIME('2020-01-01')", + write={ + "duckdb": "CAST('2020-01-01' AS TIMESTAMP)", + "hive": "CAST('2020-01-01' AS TIMESTAMP)", + "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", + }, + ) + self.validate_all( + "TIME_STR_TO_UNIX('2020-01-01')", + write={ + "duckdb": "EPOCH(CAST('2020-01-01' AS TIMESTAMP))", + "hive": "UNIX_TIMESTAMP('2020-01-01')", + "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%S'))", + }, + ) + self.validate_all( + "TIME_TO_STR(x, '%Y-%m-%d')", + write={ + "duckdb": "STRFTIME(x, '%Y-%m-%d')", + "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')", + "presto": "DATE_FORMAT(x, '%Y-%m-%d')", + }, + ) + self.validate_all( + "TIME_TO_TIME_STR(x)", + write={ + "duckdb": "CAST(x AS TEXT)", + "hive": "CAST(x AS STRING)", + "presto": "CAST(x AS VARCHAR)", + }, + ) + self.validate_all( + "TIME_TO_UNIX(x)", + write={ + "duckdb": "EPOCH(x)", + "hive": "UNIX_TIMESTAMP(x)", + "presto": "TO_UNIXTIME(x)", + }, + ) + self.validate_all( + "TS_OR_DS_TO_DATE_STR(x)", + write={ + "duckdb": "SUBSTRING(CAST(x AS TEXT), 1, 10)", + "hive": "SUBSTRING(CAST(x AS STRING), 1, 10)", + "presto": "SUBSTRING(CAST(x AS VARCHAR), 1, 10)", + }, + ) + self.validate_all( + 
"TS_OR_DS_TO_DATE(x)", + write={ + "duckdb": "CAST(x AS DATE)", + "hive": "TO_DATE(x)", + "presto": "CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE)", + }, + ) + self.validate_all( + "TS_OR_DS_TO_DATE(x, '%-d')", + write={ + "duckdb": "CAST(STRPTIME(x, '%-d') AS DATE)", + "hive": "TO_DATE(x, 'd')", + "presto": "CAST(DATE_PARSE(x, '%e') AS DATE)", + "spark": "TO_DATE(x, 'd')", + }, + ) + self.validate_all( + "UNIX_TO_STR(x, y)", + write={ + "duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), y)", + "hive": "FROM_UNIXTIME(x, y)", + "presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)", + }, + ) + self.validate_all( + "UNIX_TO_TIME(x)", + write={ + "duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))", + "hive": "FROM_UNIXTIME(x)", + "presto": "FROM_UNIXTIME(x)", + }, + ) + self.validate_all( + "UNIX_TO_TIME_STR(x)", + write={ + "duckdb": "CAST(TO_TIMESTAMP(CAST(x AS BIGINT)) AS TEXT)", + "hive": "FROM_UNIXTIME(x)", + "presto": "CAST(FROM_UNIXTIME(x) AS VARCHAR)", + }, + ) + self.validate_all( + "DATE_TO_DATE_STR(x)", + write={ + "duckdb": "CAST(x AS TEXT)", + "hive": "CAST(x AS STRING)", + "presto": "CAST(x AS VARCHAR)", + }, + ) + self.validate_all( + "DATE_TO_DI(x)", + write={ + "duckdb": "CAST(STRFTIME(x, '%Y%m%d') AS INT)", + "hive": "CAST(DATE_FORMAT(x, 'yyyyMMdd') AS INT)", + "presto": "CAST(DATE_FORMAT(x, '%Y%m%d') AS INT)", + }, + ) + self.validate_all( + "DI_TO_DATE(x)", + write={ + "duckdb": "CAST(STRPTIME(CAST(x AS TEXT), '%Y%m%d') AS DATE)", + "hive": "TO_DATE(CAST(x AS STRING), 'yyyyMMdd')", + "presto": "CAST(DATE_PARSE(CAST(x AS VARCHAR), '%Y%m%d') AS DATE)", + }, + ) + self.validate_all( + "TS_OR_DI_TO_DI(x)", + write={ + "duckdb": "CAST(SUBSTR(REPLACE(CAST(x AS TEXT), '-', ''), 1, 8) AS INT)", + "hive": "CAST(SUBSTR(REPLACE(CAST(x AS STRING), '-', ''), 1, 8) AS INT)", + "presto": "CAST(SUBSTR(REPLACE(CAST(x AS VARCHAR), '-', ''), 1, 8) AS INT)", + "spark": "CAST(SUBSTR(REPLACE(CAST(x AS STRING), '-', ''), 1, 8) AS INT)", + }, + ) + self.validate_all( + "DATE_ADD(x, 1, 'day')", + read={ + "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", + "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", + }, + write={ + "bigquery": "DATE_ADD(x, INTERVAL 1 'day')", + "duckdb": "x + INTERVAL 1 day", + "hive": "DATE_ADD(x, 1)", + "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", + "postgres": "x + INTERVAL '1' 'day'", + "presto": "DATE_ADD('day', 1, x)", + "spark": "DATE_ADD(x, 1)", + "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", + }, + ) + self.validate_all( + "DATE_ADD(x, y, 'day')", + write={ + "postgres": UnsupportedError, + }, + ) + self.validate_all( + "DATE_ADD(x, 1)", + write={ + "bigquery": "DATE_ADD(x, INTERVAL 1 'day')", + "duckdb": "x + INTERVAL 1 DAY", + "hive": "DATE_ADD(x, 1)", + "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", + "presto": "DATE_ADD('day', 1, x)", + "spark": "DATE_ADD(x, 1)", + "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", + }, + ) + self.validate_all( + "DATE_TRUNC(x, 'day')", + write={ + "mysql": "DATE(x)", + "starrocks": "DATE(x)", + }, + ) + self.validate_all( + "DATE_TRUNC(x, 'week')", + write={ + "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')", + "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')", + }, + ) + self.validate_all( + "DATE_TRUNC(x, 'month')", + write={ + "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')", + "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')", + }, + ) + self.validate_all( + "DATE_TRUNC(x, 'quarter')", + write={ + "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), 
'%Y %c %e')", + "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')", + }, + ) + self.validate_all( + "DATE_TRUNC(x, 'year')", + write={ + "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", + "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", + }, + ) + self.validate_all( + "DATE_TRUNC(x, 'millenium')", + write={ + "mysql": UnsupportedError, + "starrocks": UnsupportedError, + }, + ) + self.validate_all( + "STR_TO_DATE(x, '%Y-%m-%dT%H:%M:%S')", + read={ + "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", + "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", + }, + write={ + "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", + "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", + "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)", + "presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S') AS DATE)", + "spark": "TO_DATE(x, 'yyyy-MM-ddTHH:mm:ss')", + }, + ) + self.validate_all( + "STR_TO_DATE(x, '%Y-%m-%d')", + write={ + "mysql": "STR_TO_DATE(x, '%Y-%m-%d')", + "starrocks": "STR_TO_DATE(x, '%Y-%m-%d')", + "hive": "CAST(x AS DATE)", + "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", + "spark": "TO_DATE(x)", + }, + ) + self.validate_all( + "DATE_STR_TO_DATE(x)", + write={ + "duckdb": "CAST(x AS DATE)", + "hive": "TO_DATE(x)", + "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", + "spark": "TO_DATE(x)", + }, + ) + self.validate_all( + "TS_OR_DS_ADD('2021-02-01', 1, 'DAY')", + write={ + "duckdb": "CAST('2021-02-01' AS DATE) + INTERVAL 1 DAY", + "hive": "DATE_ADD('2021-02-01', 1)", + "presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2021-02-01', 1, 10), '%Y-%m-%d'))", + "spark": "DATE_ADD('2021-02-01', 1)", + }, + ) + self.validate_all( + "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", + write={ + "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY", + "hive": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", + "presto": "DATE_ADD('day', 1, CAST('2020-01-01' AS DATE))", + "spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", + }, + ) + + for unit in ("DAY", "MONTH", "YEAR"): + self.validate_all( + f"{unit}(x)", + read={ + dialect: f"{unit}(x)" + for dialect in ( + "bigquery", + "duckdb", + "mysql", + "presto", + "starrocks", + ) + }, + write={ + dialect: f"{unit}(x)" + for dialect in ( + "bigquery", + "duckdb", + "mysql", + "presto", + "hive", + "spark", + "starrocks", + ) + }, + ) + + def test_array(self): + self.validate_all( + "ARRAY(0, 1, 2)", + write={ + "bigquery": "[0, 1, 2]", + "duckdb": "LIST_VALUE(0, 1, 2)", + "presto": "ARRAY[0, 1, 2]", + "spark": "ARRAY(0, 1, 2)", + }, + ) + self.validate_all( + "ARRAY_SIZE(x)", + write={ + "bigquery": "ARRAY_LENGTH(x)", + "duckdb": "ARRAY_LENGTH(x)", + "presto": "CARDINALITY(x)", + "spark": "SIZE(x)", + }, + ) + self.validate_all( + "ARRAY_SUM(ARRAY(1, 2))", + write={ + "trino": "REDUCE(ARRAY[1, 2], 0, (acc, x) -> acc + x, acc -> acc)", + "duckdb": "LIST_SUM(LIST_VALUE(1, 2))", + "hive": "ARRAY_SUM(ARRAY(1, 2))", + "presto": "ARRAY_SUM(ARRAY[1, 2])", + "spark": "AGGREGATE(ARRAY(1, 2), 0, (acc, x) -> acc + x, acc -> acc)", + }, + ) + self.validate_all( + "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", + write={ + "trino": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", + "duckdb": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", + "hive": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", + "presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", + "spark": "AGGREGATE(x, 0, (acc, x) -> acc + x, acc -> acc)", + }, + ) + + def test_order_by(self): + self.validate_all( + "SELECT fname, 
lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + }, + ) + + def test_json(self): + self.validate_all( + "JSON_EXTRACT(x, 'y')", + read={ + "postgres": "x->'y'", + "presto": "JSON_EXTRACT(x, 'y')", + }, + write={ + "postgres": "x->'y'", + "presto": "JSON_EXTRACT(x, 'y')", + }, + ) + self.validate_all( + "JSON_EXTRACT_SCALAR(x, 'y')", + read={ + "postgres": "x->>'y'", + "presto": "JSON_EXTRACT_SCALAR(x, 'y')", + }, + write={ + "postgres": "x->>'y'", + "presto": "JSON_EXTRACT_SCALAR(x, 'y')", + }, + ) + self.validate_all( + "JSONB_EXTRACT(x, 'y')", + read={ + "postgres": "x#>'y'", + }, + write={ + "postgres": "x#>'y'", + }, + ) + self.validate_all( + "JSONB_EXTRACT_SCALAR(x, 'y')", + read={ + "postgres": "x#>>'y'", + }, + write={ + "postgres": "x#>>'y'", + }, + ) + + def test_cross_join(self): + self.validate_all( + "SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)", + write={ + "presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)", + "spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a", + }, + ) + self.validate_all( + "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t (a, b)", + write={ + "presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)", + "spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b", + }, + ) + self.validate_all( + "SELECT a FROM x CROSS JOIN UNNEST(y) WITH ORDINALITY AS t (a)", + write={ + "presto": "SELECT a FROM x CROSS JOIN UNNEST(y) WITH ORDINALITY AS t(a)", + "spark": "SELECT a FROM x LATERAL VIEW POSEXPLODE(y) t AS a", + }, + ) + + def test_set_operators(self): + self.validate_all( + "SELECT * FROM a UNION SELECT * FROM b", + read={ + "bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a UNION SELECT * FROM b", + "presto": "SELECT * FROM a UNION SELECT * FROM b", + "spark": "SELECT * FROM a UNION SELECT * FROM b", + }, + write={ + "bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a UNION SELECT * FROM b", + "presto": "SELECT * FROM a UNION SELECT * FROM b", + "spark": "SELECT * FROM a UNION SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a UNION ALL SELECT * FROM b", + read={ + "bigquery": "SELECT * FROM a UNION ALL SELECT * FROM b", + "duckdb": "SELECT * FROM a UNION ALL SELECT * FROM b", + "presto": "SELECT * FROM a UNION ALL SELECT * FROM b", + "spark": "SELECT * FROM a UNION ALL SELECT * FROM b", + }, + write={ + "bigquery": "SELECT * FROM a UNION ALL SELECT * FROM b", + "duckdb": "SELECT * FROM a UNION ALL SELECT * FROM b", + "presto": "SELECT * FROM a UNION ALL SELECT * FROM b", + "spark": "SELECT * FROM a UNION ALL SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a INTERSECT SELECT * FROM b", + read={ + "bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b", + "presto": "SELECT * FROM a INTERSECT SELECT * FROM b", + "spark": "SELECT * FROM a INTERSECT SELECT * FROM 
b", + }, + write={ + "bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b", + "presto": "SELECT * FROM a INTERSECT SELECT * FROM b", + "spark": "SELECT * FROM a INTERSECT SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a EXCEPT SELECT * FROM b", + read={ + "bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b", + "presto": "SELECT * FROM a EXCEPT SELECT * FROM b", + "spark": "SELECT * FROM a EXCEPT SELECT * FROM b", + }, + write={ + "bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b", + "presto": "SELECT * FROM a EXCEPT SELECT * FROM b", + "spark": "SELECT * FROM a EXCEPT SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a UNION DISTINCT SELECT * FROM b", + write={ + "bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a UNION SELECT * FROM b", + "presto": "SELECT * FROM a UNION SELECT * FROM b", + "spark": "SELECT * FROM a UNION SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", + write={ + "bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b", + "presto": "SELECT * FROM a INTERSECT SELECT * FROM b", + "spark": "SELECT * FROM a INTERSECT SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + write={ + "bigquery": "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + "duckdb": "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + "presto": "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + "spark": "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", + write={ + "bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b", + "presto": "SELECT * FROM a EXCEPT SELECT * FROM b", + "spark": "SELECT * FROM a EXCEPT SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + read={ + "bigquery": "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + "duckdb": "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + "presto": "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + "spark": "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + }, + ) + + def test_operators(self): + self.validate_all( + "x ILIKE '%y'", + read={ + "clickhouse": "x ILIKE '%y'", + "duckdb": "x ILIKE '%y'", + "postgres": "x ILIKE '%y'", + "snowflake": "x ILIKE '%y'", + }, + write={ + "bigquery": "LOWER(x) LIKE '%y'", + "clickhouse": "x ILIKE '%y'", + "duckdb": "x ILIKE '%y'", + "hive": "LOWER(x) LIKE '%y'", + "mysql": "LOWER(x) LIKE '%y'", + "oracle": "LOWER(x) LIKE '%y'", + "postgres": "x ILIKE '%y'", + "presto": "LOWER(x) LIKE '%y'", + "snowflake": "x ILIKE '%y'", + "spark": "LOWER(x) LIKE '%y'", + "sqlite": "LOWER(x) LIKE '%y'", + "starrocks": "LOWER(x) LIKE '%y'", + "trino": "LOWER(x) LIKE '%y'", + }, + ) + self.validate_all( + "SELECT * FROM a ORDER BY col_a NULLS LAST", + write={ + "mysql": UnsupportedError, + "starrocks": UnsupportedError, + }, + ) + self.validate_all( + "STR_POSITION(x, 'a')", + write={ + "duckdb": "STRPOS(x, 'a')", + "presto": "STRPOS(x, 'a')", + "spark": "LOCATE('a', x)", + }, + ) + self.validate_all( + "CONCAT_WS('-', 'a', 'b')", + write={ + "duckdb": "CONCAT_WS('-', 'a', 'b')", + "presto": "ARRAY_JOIN(ARRAY['a', 'b'], '-')", + "hive": 
"CONCAT_WS('-', 'a', 'b')", + "spark": "CONCAT_WS('-', 'a', 'b')", + }, + ) + + self.validate_all( + "CONCAT_WS('-', x)", + write={ + "duckdb": "CONCAT_WS('-', x)", + "presto": "ARRAY_JOIN(x, '-')", + "hive": "CONCAT_WS('-', x)", + "spark": "CONCAT_WS('-', x)", + }, + ) + self.validate_all( + "IF(x > 1, 1, 0)", + write={ + "duckdb": "CASE WHEN x > 1 THEN 1 ELSE 0 END", + "presto": "IF(x > 1, 1, 0)", + "hive": "IF(x > 1, 1, 0)", + "spark": "IF(x > 1, 1, 0)", + "tableau": "IF x > 1 THEN 1 ELSE 0 END", + }, + ) + self.validate_all( + "CASE WHEN 1 THEN x ELSE 0 END", + write={ + "duckdb": "CASE WHEN 1 THEN x ELSE 0 END", + "presto": "CASE WHEN 1 THEN x ELSE 0 END", + "hive": "CASE WHEN 1 THEN x ELSE 0 END", + "spark": "CASE WHEN 1 THEN x ELSE 0 END", + "tableau": "CASE WHEN 1 THEN x ELSE 0 END", + }, + ) + self.validate_all( + "x[y]", + write={ + "duckdb": "x[y]", + "presto": "x[y]", + "hive": "x[y]", + "spark": "x[y]", + }, + ) + self.validate_all( + """'["x"]'""", + write={ + "duckdb": """'["x"]'""", + "presto": """'["x"]'""", + "hive": """'["x"]'""", + "spark": """'["x"]'""", + }, + ) + + self.validate_all( + 'true or null as "foo"', + write={ + "bigquery": "TRUE OR NULL AS `foo`", + "duckdb": 'TRUE OR NULL AS "foo"', + "presto": 'TRUE OR NULL AS "foo"', + "hive": "TRUE OR NULL AS `foo`", + "spark": "TRUE OR NULL AS `foo`", + }, + ) + self.validate_all( + "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) as foo FROM baz", + write={ + "bigquery": "SELECT CASE WHEN COALESCE(bar, 0) = 1 THEN TRUE ELSE FALSE END AS foo FROM baz", + "duckdb": "SELECT CASE WHEN COALESCE(bar, 0) = 1 THEN TRUE ELSE FALSE END AS foo FROM baz", + "presto": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz", + "hive": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz", + "spark": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz", + }, + ) + self.validate_all( + "LEVENSHTEIN(col1, col2)", + write={ + "duckdb": "LEVENSHTEIN(col1, col2)", + "presto": "LEVENSHTEIN_DISTANCE(col1, col2)", + "hive": "LEVENSHTEIN(col1, col2)", + "spark": "LEVENSHTEIN(col1, col2)", + }, + ) + self.validate_all( + "LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))", + write={ + "duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", + "presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))", + "hive": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", + "spark": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", + }, + ) + self.validate_all( + "ARRAY_FILTER(the_array, x -> x > 0)", + write={ + "presto": "FILTER(the_array, x -> x > 0)", + "hive": "FILTER(the_array, x -> x > 0)", + "spark": "FILTER(the_array, x -> x > 0)", + }, + ) + self.validate_all( + "SELECT a AS b FROM x GROUP BY b", + write={ + "duckdb": "SELECT a AS b FROM x GROUP BY b", + "presto": "SELECT a AS b FROM x GROUP BY 1", + "hive": "SELECT a AS b FROM x GROUP BY 1", + "oracle": "SELECT a AS b FROM x GROUP BY 1", + "spark": "SELECT a AS b FROM x GROUP BY 1", + }, + ) + self.validate_all( + "SELECT x FROM y LIMIT 10", + write={ + "sqlite": "SELECT x FROM y LIMIT 10", + "oracle": "SELECT x FROM y FETCH FIRST 10 ROWS ONLY", + }, + ) + self.validate_all( + "SELECT x FROM y LIMIT 10 OFFSET 5", + write={ + "sqlite": "SELECT x FROM y LIMIT 10 OFFSET 5", + "oracle": "SELECT x FROM y OFFSET 5 ROWS FETCH FIRST 10 ROWS ONLY", + }, + ) + self.validate_all( + "SELECT x FROM y OFFSET 10 FETCH FIRST 3 ROWS ONLY", + write={ + "oracle": "SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY", + }, + ) + 
self.validate_all( + "SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY", + write={ + "oracle": "SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY", + }, + ) + self.validate_all( + '"x" + "y"', + read={ + "clickhouse": '`x` + "y"', + "sqlite": '`x` + "y"', + }, + ) + self.validate_all( + "[1, 2]", + write={ + "bigquery": "[1, 2]", + "clickhouse": "[1, 2]", + }, + ) + self.validate_all( + "SELECT * FROM VALUES ('x'), ('y') AS t(z)", + write={ + "spark": "SELECT * FROM (VALUES ('x'), ('y')) AS t(z)", + }, + ) + self.validate_all( + "CREATE TABLE t (c CHAR, nc NCHAR, v1 VARCHAR, v2 VARCHAR2, nv NVARCHAR, nv2 NVARCHAR2)", + write={ + "hive": "CREATE TABLE t (c CHAR, nc CHAR, v1 STRING, v2 STRING, nv STRING, nv2 STRING)", + "oracle": "CREATE TABLE t (c CHAR, nc CHAR, v1 VARCHAR2, v2 VARCHAR2, nv NVARCHAR2, nv2 NVARCHAR2)", + "postgres": "CREATE TABLE t (c CHAR, nc CHAR, v1 VARCHAR, v2 VARCHAR, nv VARCHAR, nv2 VARCHAR)", + "sqlite": "CREATE TABLE t (c TEXT, nc TEXT, v1 TEXT, v2 TEXT, nv TEXT, nv2 TEXT)", + }, + ) + self.validate_all( + "POWER(1.2, 3.4)", + read={ + "hive": "pow(1.2, 3.4)", + "postgres": "power(1.2, 3.4)", + }, + ) + self.validate_all( + "CREATE INDEX my_idx ON tbl (a, b)", + read={ + "hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)", + "sqlite": "CREATE INDEX my_idx ON tbl (a, b)", + }, + write={ + "hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)", + "postgres": "CREATE INDEX my_idx ON tbl (a, b)", + "sqlite": "CREATE INDEX my_idx ON tbl (a, b)", + }, + ) + self.validate_all( + "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + read={ + "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)", + "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + }, + write={ + "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)", + "postgres": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + }, + ) + self.validate_all( + "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 TEXT, c2 TEXT(1024))", + write={ + "hive": "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 STRING, c2 STRING(1024))", + "oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))", + "postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))", + "sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))", + }, + ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py new file mode 100644 index 0000000..501301f --- /dev/null +++ b/tests/dialects/test_duckdb.py @@ -0,0 +1,249 @@ +from tests.dialects.test_dialect import Validator + + +class TestDuckDB(Validator): + dialect = "duckdb" + + def test_time(self): + self.validate_all( + "EPOCH(x)", + read={ + "presto": "TO_UNIXTIME(x)", + }, + write={ + "bigquery": "TIME_TO_UNIX(x)", + "duckdb": "EPOCH(x)", + "presto": "TO_UNIXTIME(x)", + "spark": "UNIX_TIMESTAMP(x)", + }, + ) + self.validate_all( + "EPOCH_MS(x)", + write={ + "bigquery": "UNIX_TO_TIME(x / 1000)", + "duckdb": "TO_TIMESTAMP(CAST(x / 1000 AS BIGINT))", + "presto": "FROM_UNIXTIME(x / 1000)", + "spark": "FROM_UNIXTIME(x / 1000)", + }, + ) + self.validate_all( + "STRFTIME(x, '%y-%-m-%S')", + write={ + "bigquery": "TIME_TO_STR(x, '%y-%-m-%S')", + "duckdb": "STRFTIME(x, '%y-%-m-%S')", + "postgres": "TO_CHAR(x, 'YY-FMMM-SS')", + "presto": "DATE_FORMAT(x, '%y-%c-%S')", + "spark": "DATE_FORMAT(x, 'yy-M-ss')", + }, + ) + self.validate_all( + "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", + write={ + "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", + "presto": "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", + "hive": 
"DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", + }, + ) + self.validate_all( + "STRPTIME(x, '%y-%-m')", + write={ + "bigquery": "STR_TO_TIME(x, '%y-%-m')", + "duckdb": "STRPTIME(x, '%y-%-m')", + "presto": "DATE_PARSE(x, '%y-%c')", + "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy-M')) AS TIMESTAMP)", + "spark": "TO_TIMESTAMP(x, 'yy-M')", + }, + ) + self.validate_all( + "TO_TIMESTAMP(x)", + write={ + "duckdb": "CAST(x AS TIMESTAMP)", + "presto": "DATE_PARSE(x, '%Y-%m-%d %H:%i:%s')", + "hive": "CAST(x AS TIMESTAMP)", + }, + ) + + def test_duckdb(self): + self.validate_all( + "LIST_VALUE(0, 1, 2)", + write={ + "bigquery": "[0, 1, 2]", + "duckdb": "LIST_VALUE(0, 1, 2)", + "presto": "ARRAY[0, 1, 2]", + "spark": "ARRAY(0, 1, 2)", + }, + ) + self.validate_all( + "REGEXP_MATCHES(x, y)", + write={ + "duckdb": "REGEXP_MATCHES(x, y)", + "presto": "REGEXP_LIKE(x, y)", + "hive": "x RLIKE y", + "spark": "x RLIKE y", + }, + ) + self.validate_all( + "STR_SPLIT(x, 'a')", + write={ + "duckdb": "STR_SPLIT(x, 'a')", + "presto": "SPLIT(x, 'a')", + "hive": "SPLIT(x, CONCAT('\\\\Q', 'a'))", + "spark": "SPLIT(x, CONCAT('\\\\Q', 'a'))", + }, + ) + self.validate_all( + "STRING_TO_ARRAY(x, 'a')", + write={ + "duckdb": "STR_SPLIT(x, 'a')", + "presto": "SPLIT(x, 'a')", + "hive": "SPLIT(x, CONCAT('\\\\Q', 'a'))", + "spark": "SPLIT(x, CONCAT('\\\\Q', 'a'))", + }, + ) + self.validate_all( + "STR_SPLIT_REGEX(x, 'a')", + write={ + "duckdb": "STR_SPLIT_REGEX(x, 'a')", + "presto": "REGEXP_SPLIT(x, 'a')", + "hive": "SPLIT(x, 'a')", + "spark": "SPLIT(x, 'a')", + }, + ) + self.validate_all( + "STRUCT_EXTRACT(x, 'abc')", + write={ + "duckdb": "STRUCT_EXTRACT(x, 'abc')", + "presto": 'x."abc"', + "hive": "x.`abc`", + "spark": "x.`abc`", + }, + ) + self.validate_all( + "STRUCT_EXTRACT(STRUCT_EXTRACT(x, 'y'), 'abc')", + write={ + "duckdb": "STRUCT_EXTRACT(STRUCT_EXTRACT(x, 'y'), 'abc')", + "presto": 'x."y"."abc"', + "hive": "x.`y`.`abc`", + "spark": "x.`y`.`abc`", + }, + ) + + self.validate_all( + "QUANTILE(x, 0.5)", + write={ + "duckdb": "QUANTILE(x, 0.5)", + "presto": "APPROX_PERCENTILE(x, 0.5)", + "hive": "PERCENTILE(x, 0.5)", + "spark": "PERCENTILE(x, 0.5)", + }, + ) + + self.validate_all( + "CAST(x AS DATE)", + write={ + "duckdb": "CAST(x AS DATE)", + "": "CAST(x AS DATE)", + }, + ) + self.validate_all( + "UNNEST(x)", + read={ + "spark": "EXPLODE(x)", + }, + write={ + "duckdb": "UNNEST(x)", + "spark": "EXPLODE(x)", + }, + ) + + self.validate_all( + "1d", + write={ + "duckdb": "1 AS d", + "spark": "1 AS d", + }, + ) + self.validate_all( + "CAST(1 AS DOUBLE)", + read={ + "hive": "1d", + "spark": "1d", + }, + ) + self.validate_all( + "POWER(CAST(2 AS SMALLINT), 3)", + read={ + "hive": "POW(2S, 3)", + "spark": "POW(2S, 3)", + }, + ) + self.validate_all( + "LIST_SUM(LIST_VALUE(1, 2))", + read={ + "spark": "ARRAY_SUM(ARRAY(1, 2))", + }, + ) + self.validate_all( + "IF(y <> 0, x / y, NULL)", + read={ + "bigquery": "SAFE_DIVIDE(x, y)", + }, + ) + self.validate_all( + "STRUCT_PACK(x := 1, y := '2')", + write={ + "duckdb": "STRUCT_PACK(x := 1, y := '2')", + "spark": "STRUCT(x = 1, y = '2')", + }, + ) + self.validate_all( + "ARRAY_SORT(x)", + write={ + "duckdb": "ARRAY_SORT(x)", + "presto": "ARRAY_SORT(x)", + "hive": "SORT_ARRAY(x)", + "spark": "SORT_ARRAY(x)", + }, + ) + self.validate_all( + "ARRAY_REVERSE_SORT(x)", + write={ + "duckdb": "ARRAY_REVERSE_SORT(x)", + "presto": "ARRAY_SORT(x, (a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END)", + "hive": "SORT_ARRAY(x, FALSE)", + "spark": "SORT_ARRAY(x, FALSE)", + }, + ) + 
self.validate_all( + "LIST_REVERSE_SORT(x)", + write={ + "duckdb": "ARRAY_REVERSE_SORT(x)", + "presto": "ARRAY_SORT(x, (a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END)", + "hive": "SORT_ARRAY(x, FALSE)", + "spark": "SORT_ARRAY(x, FALSE)", + }, + ) + self.validate_all( + "LIST_SORT(x)", + write={ + "duckdb": "ARRAY_SORT(x)", + "presto": "ARRAY_SORT(x)", + "hive": "SORT_ARRAY(x)", + "spark": "SORT_ARRAY(x)", + }, + ) + self.validate_all( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + }, + ) + self.validate_all( + "MONTH('2021-03-01')", + write={ + "duckdb": "MONTH('2021-03-01')", + "presto": "MONTH('2021-03-01')", + "hive": "MONTH('2021-03-01')", + "spark": "MONTH('2021-03-01')", + }, + ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py new file mode 100644 index 0000000..eccd75a --- /dev/null +++ b/tests/dialects/test_hive.py @@ -0,0 +1,541 @@ +from tests.dialects.test_dialect import Validator + + +class TestHive(Validator): + dialect = "hive" + + def test_bits(self): + self.validate_all( + "x & 1", + write={ + "duckdb": "x & 1", + "presto": "BITWISE_AND(x, 1)", + "hive": "x & 1", + "spark": "x & 1", + }, + ) + self.validate_all( + "~x", + write={ + "duckdb": "~x", + "presto": "BITWISE_NOT(x)", + "hive": "~x", + "spark": "~x", + }, + ) + self.validate_all( + "x | 1", + write={ + "duckdb": "x | 1", + "presto": "BITWISE_OR(x, 1)", + "hive": "x | 1", + "spark": "x | 1", + }, + ) + self.validate_all( + "x << 1", + read={ + "spark": "SHIFTLEFT(x, 1)", + }, + write={ + "duckdb": "x << 1", + "presto": "BITWISE_ARITHMETIC_SHIFT_LEFT(x, 1)", + "hive": "x << 1", + "spark": "SHIFTLEFT(x, 1)", + }, + ) + self.validate_all( + "x >> 1", + read={ + "spark": "SHIFTRIGHT(x, 1)", + }, + write={ + "duckdb": "x >> 1", + "presto": "BITWISE_ARITHMETIC_SHIFT_RIGHT(x, 1)", + "hive": "x >> 1", + "spark": "SHIFTRIGHT(x, 1)", + }, + ) + self.validate_all( + "x & 1 > 0", + write={ + "duckdb": "x & 1 > 0", + "presto": "BITWISE_AND(x, 1) > 0", + "hive": "x & 1 > 0", + "spark": "x & 1 > 0", + }, + ) + + def test_cast(self): + self.validate_all( + "1s", + write={ + "duckdb": "CAST(1 AS SMALLINT)", + "presto": "CAST(1 AS SMALLINT)", + "hive": "CAST(1 AS SMALLINT)", + "spark": "CAST(1 AS SHORT)", + }, + ) + self.validate_all( + "1S", + write={ + "duckdb": "CAST(1 AS SMALLINT)", + "presto": "CAST(1 AS SMALLINT)", + "hive": "CAST(1 AS SMALLINT)", + "spark": "CAST(1 AS SHORT)", + }, + ) + self.validate_all( + "1Y", + write={ + "duckdb": "CAST(1 AS TINYINT)", + "presto": "CAST(1 AS TINYINT)", + "hive": "CAST(1 AS TINYINT)", + "spark": "CAST(1 AS BYTE)", + }, + ) + self.validate_all( + "1L", + write={ + "duckdb": "CAST(1 AS BIGINT)", + "presto": "CAST(1 AS BIGINT)", + "hive": "CAST(1 AS BIGINT)", + "spark": "CAST(1 AS LONG)", + }, + ) + self.validate_all( + "1.0bd", + write={ + "duckdb": "CAST(1.0 AS DECIMAL)", + "presto": "CAST(1.0 AS DECIMAL)", + "hive": "CAST(1.0 AS DECIMAL)", + "spark": "CAST(1.0 AS DECIMAL)", + }, + ) + self.validate_all( + "CAST(1 AS INT)", + read={ + "presto": "TRY_CAST(1 AS INT)", + }, + write={ + "duckdb": "TRY_CAST(1 AS INT)", + "presto": "TRY_CAST(1 AS INTEGER)", + "hive": "CAST(1 AS INT)", + "spark": "CAST(1 AS INT)", + }, + ) + + def test_ddl(self): + self.validate_all( + "CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", + write={ + "presto": "CREATE TABLE 
test WITH (FORMAT = 'parquet', x = '1', Z = '2') AS SELECT 1", + "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", + "spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", + }, + ) + self.validate_all( + "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", + write={ + "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])", + "hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", + "spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", + }, + ) + + def test_lateral_view(self): + self.validate_all( + "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b", + write={ + "presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y) AS t(a) CROSS JOIN UNNEST(z) AS u(b)", + "hive": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b", + "spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b", + }, + ) + self.validate_all( + "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a", + write={ + "presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)", + "hive": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a", + "spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a", + }, + ) + self.validate_all( + "SELECT a FROM x LATERAL VIEW POSEXPLODE(y) t AS a", + write={ + "presto": "SELECT a FROM x CROSS JOIN UNNEST(y) WITH ORDINALITY AS t(a)", + "hive": "SELECT a FROM x LATERAL VIEW POSEXPLODE(y) t AS a", + "spark": "SELECT a FROM x LATERAL VIEW POSEXPLODE(y) t AS a", + }, + ) + self.validate_all( + "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a", + write={ + "presto": "SELECT a FROM x CROSS JOIN UNNEST(ARRAY[y]) AS t(a)", + "hive": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a", + "spark": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a", + }, + ) + + def test_quotes(self): + self.validate_all( + "'\\''", + write={ + "duckdb": "''''", + "presto": "''''", + "hive": "'\\''", + "spark": "'\\''", + }, + ) + self.validate_all( + "'\"x\"'", + write={ + "duckdb": "'\"x\"'", + "presto": "'\"x\"'", + "hive": "'\"x\"'", + "spark": "'\"x\"'", + }, + ) + self.validate_all( + "\"'x'\"", + write={ + "duckdb": "'''x'''", + "presto": "'''x'''", + "hive": "'\\'x\\''", + "spark": "'\\'x\\''", + }, + ) + self.validate_all( + "'\\\\a'", + read={ + "presto": "'\\a'", + }, + write={ + "duckdb": "'\\a'", + "presto": "'\\a'", + "hive": "'\\\\a'", + "spark": "'\\\\a'", + }, + ) + + def test_regex(self): + self.validate_all( + "a RLIKE 'x'", + write={ + "duckdb": "REGEXP_MATCHES(a, 'x')", + "presto": "REGEXP_LIKE(a, 'x')", + "hive": "a RLIKE 'x'", + "spark": "a RLIKE 'x'", + }, + ) + + self.validate_all( + "a REGEXP 'x'", + write={ + "duckdb": "REGEXP_MATCHES(a, 'x')", + "presto": "REGEXP_LIKE(a, 'x')", + "hive": "a RLIKE 'x'", + "spark": "a RLIKE 'x'", + }, + ) + + def test_time(self): + self.validate_all( + "DATEDIFF(a, b)", + write={ + "duckdb": "DATE_DIFF('day', CAST(b AS DATE), CAST(a AS DATE))", + "presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(b AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(a AS VARCHAR), 1, 10) AS DATE))", + "hive": "DATEDIFF(TO_DATE(a), TO_DATE(b))", + "spark": "DATEDIFF(TO_DATE(a), TO_DATE(b))", + "": "DATE_DIFF(TS_OR_DS_TO_DATE(a), TS_OR_DS_TO_DATE(b))", + }, + ) + self.validate_all( + """from_unixtime(x, "yyyy-MM-dd'T'HH")""", + write={ + "duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), '%Y-%m-%d''T''%H')", + "presto": 
"DATE_FORMAT(FROM_UNIXTIME(x), '%Y-%m-%d''T''%H')", + "hive": "FROM_UNIXTIME(x, 'yyyy-MM-dd\\'T\\'HH')", + "spark": "FROM_UNIXTIME(x, 'yyyy-MM-dd\\'T\\'HH')", + }, + ) + self.validate_all( + "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", + write={ + "duckdb": "STRFTIME('2020-01-01', '%Y-%m-%d %H:%M:%S')", + "presto": "DATE_FORMAT('2020-01-01', '%Y-%m-%d %H:%i:%S')", + "hive": "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", + "spark": "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", + }, + ) + self.validate_all( + "DATE_ADD('2020-01-01', 1)", + write={ + "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY", + "presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2020-01-01', 1, 10), '%Y-%m-%d'))", + "hive": "DATE_ADD('2020-01-01', 1)", + "spark": "DATE_ADD('2020-01-01', 1)", + "": "TS_OR_DS_ADD('2020-01-01', 1, 'DAY')", + }, + ) + self.validate_all( + "DATE_SUB('2020-01-01', 1)", + write={ + "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 * -1 DAY", + "presto": "DATE_ADD('DAY', 1 * -1, DATE_PARSE(SUBSTR('2020-01-01', 1, 10), '%Y-%m-%d'))", + "hive": "DATE_ADD('2020-01-01', 1 * -1)", + "spark": "DATE_ADD('2020-01-01', 1 * -1)", + "": "TS_OR_DS_ADD('2020-01-01', 1 * -1, 'DAY')", + }, + ) + self.validate_all( + "DATEDIFF(TO_DATE(y), x)", + write={ + "duckdb": "DATE_DIFF('day', CAST(x AS DATE), CAST(CAST(y AS DATE) AS DATE))", + "presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(CAST(SUBSTR(CAST(y AS VARCHAR), 1, 10) AS DATE) AS VARCHAR), 1, 10) AS DATE))", + "hive": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))", + "spark": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))", + "": "DATE_DIFF(TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE(y)), TS_OR_DS_TO_DATE(x))", + }, + ) + self.validate_all( + "UNIX_TIMESTAMP(x)", + write={ + "duckdb": "EPOCH(STRPTIME(x, '%Y-%m-%d %H:%M:%S'))", + "presto": "TO_UNIXTIME(DATE_PARSE(x, '%Y-%m-%d %H:%i:%S'))", + "hive": "UNIX_TIMESTAMP(x)", + "spark": "UNIX_TIMESTAMP(x)", + "": "STR_TO_UNIX(x, '%Y-%m-%d %H:%M:%S')", + }, + ) + + for unit in ("DAY", "MONTH", "YEAR"): + self.validate_all( + f"{unit}(x)", + write={ + "duckdb": f"{unit}(CAST(x AS DATE))", + "presto": f"{unit}(CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE))", + "hive": f"{unit}(TO_DATE(x))", + "spark": f"{unit}(TO_DATE(x))", + }, + ) + + def test_order_by(self): + self.validate_all( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + }, + ) + + def test_hive(self): + self.validate_all( + "PERCENTILE(x, 0.5)", + write={ + "duckdb": "QUANTILE(x, 0.5)", + "presto": "APPROX_PERCENTILE(x, 0.5)", + "hive": "PERCENTILE(x, 0.5)", + "spark": "PERCENTILE(x, 0.5)", + }, + ) + self.validate_all( + "APPROX_COUNT_DISTINCT(a)", + write={ + "duckdb": "APPROX_COUNT_DISTINCT(a)", + "presto": "APPROX_DISTINCT(a)", + "hive": "APPROX_COUNT_DISTINCT(a)", + "spark": "APPROX_COUNT_DISTINCT(a)", + }, + ) + self.validate_all( + "ARRAY_CONTAINS(x, 1)", + write={ + "duckdb": "ARRAY_CONTAINS(x, 1)", + "presto": "CONTAINS(x, 1)", + "hive": "ARRAY_CONTAINS(x, 1)", + "spark": "ARRAY_CONTAINS(x, 1)", + }, + ) + 
self.validate_all( + "SIZE(x)", + write={ + "duckdb": "ARRAY_LENGTH(x)", + "presto": "CARDINALITY(x)", + "hive": "SIZE(x)", + "spark": "SIZE(x)", + }, + ) + self.validate_all( + "LOCATE('a', x)", + write={ + "duckdb": "STRPOS(x, 'a')", + "presto": "STRPOS(x, 'a')", + "hive": "LOCATE('a', x)", + "spark": "LOCATE('a', x)", + }, + ) + self.validate_all( + "LOCATE('a', x, 3)", + write={ + "duckdb": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", + "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", + "hive": "LOCATE('a', x, 3)", + "spark": "LOCATE('a', x, 3)", + }, + ) + self.validate_all( + "INITCAP('new york')", + write={ + "duckdb": "INITCAP('new york')", + "presto": "REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))", + "hive": "INITCAP('new york')", + "spark": "INITCAP('new york')", + }, + ) + self.validate_all( + "SELECT * FROM x TABLESAMPLE(10) y", + write={ + "presto": "SELECT * FROM x AS y TABLESAMPLE(10)", + "hive": "SELECT * FROM x TABLESAMPLE(10) AS y", + "spark": "SELECT * FROM x TABLESAMPLE(10) AS y", + }, + ) + self.validate_all( + "SELECT SORT_ARRAY(x)", + write={ + "duckdb": "SELECT ARRAY_SORT(x)", + "presto": "SELECT ARRAY_SORT(x)", + "hive": "SELECT SORT_ARRAY(x)", + "spark": "SELECT SORT_ARRAY(x)", + }, + ) + self.validate_all( + "SELECT SORT_ARRAY(x, FALSE)", + read={ + "duckdb": "SELECT ARRAY_REVERSE_SORT(x)", + "spark": "SELECT SORT_ARRAY(x, FALSE)", + }, + write={ + "duckdb": "SELECT ARRAY_REVERSE_SORT(x)", + "presto": "SELECT ARRAY_SORT(x, (a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END)", + "hive": "SELECT SORT_ARRAY(x, FALSE)", + "spark": "SELECT SORT_ARRAY(x, FALSE)", + }, + ) + self.validate_all( + "GET_JSON_OBJECT(x, '$.name')", + write={ + "presto": "JSON_EXTRACT_SCALAR(x, '$.name')", + "hive": "GET_JSON_OBJECT(x, '$.name')", + "spark": "GET_JSON_OBJECT(x, '$.name')", + }, + ) + self.validate_all( + "MAP(a, b, c, d)", + write={ + "duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))", + "presto": "MAP(ARRAY[a, c], ARRAY[b, d])", + "hive": "MAP(a, b, c, d)", + "spark": "MAP_FROM_ARRAYS(ARRAY(a, c), ARRAY(b, d))", + }, + ) + self.validate_all( + "MAP(a, b)", + write={ + "duckdb": "MAP(LIST_VALUE(a), LIST_VALUE(b))", + "presto": "MAP(ARRAY[a], ARRAY[b])", + "hive": "MAP(a, b)", + "spark": "MAP_FROM_ARRAYS(ARRAY(a), ARRAY(b))", + }, + ) + self.validate_all( + "LOG(10)", + write={ + "duckdb": "LN(10)", + "presto": "LN(10)", + "hive": "LN(10)", + "spark": "LN(10)", + }, + ) + self.validate_all( + "LOG(10, 2)", + write={ + "duckdb": "LOG(10, 2)", + "presto": "LOG(10, 2)", + "hive": "LOG(10, 2)", + "spark": "LOG(10, 2)", + }, + ) + self.validate_all( + 'ds = "2020-01-01"', + write={ + "duckdb": "ds = '2020-01-01'", + "presto": "ds = '2020-01-01'", + "hive": "ds = '2020-01-01'", + "spark": "ds = '2020-01-01'", + }, + ) + self.validate_all( + "ds = \"1''2\"", + write={ + "duckdb": "ds = '1''''2'", + "presto": "ds = '1''''2'", + "hive": "ds = '1\\'\\'2'", + "spark": "ds = '1\\'\\'2'", + }, + ) + self.validate_all( + "x == 1", + write={ + "duckdb": "x = 1", + "presto": "x = 1", + "hive": "x = 1", + "spark": "x = 1", + }, + ) + self.validate_all( + "x div y", + write={ + "duckdb": "CAST(x / y AS INT)", + "presto": "CAST(x / y AS INTEGER)", + "hive": "CAST(x / y AS INT)", + "spark": "CAST(x / y AS INT)", + }, + ) + self.validate_all( + "COLLECT_LIST(x)", + read={ + "presto": "ARRAY_AGG(x)", + }, + write={ + "duckdb": "ARRAY_AGG(x)", + "presto": "ARRAY_AGG(x)", + "hive": "COLLECT_LIST(x)", + "spark": "COLLECT_LIST(x)", + }, + ) + self.validate_all( + 
"COLLECT_SET(x)", + read={ + "presto": "SET_AGG(x)", + }, + write={ + "presto": "SET_AGG(x)", + "hive": "COLLECT_SET(x)", + "spark": "COLLECT_SET(x)", + }, + ) + self.validate_all( + "SELECT * FROM x TABLESAMPLE(1) AS foo", + read={ + "presto": "SELECT * FROM x AS foo TABLESAMPLE(1)", + }, + write={ + "presto": "SELECT * FROM x AS foo TABLESAMPLE(1)", + "hive": "SELECT * FROM x TABLESAMPLE(1) AS foo", + "spark": "SELECT * FROM x TABLESAMPLE(1) AS foo", + }, + ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py new file mode 100644 index 0000000..ee0c5f5 --- /dev/null +++ b/tests/dialects/test_mysql.py @@ -0,0 +1,79 @@ +from tests.dialects.test_dialect import Validator + + +class TestMySQL(Validator): + dialect = "mysql" + + def test_ddl(self): + self.validate_all( + "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'", + write={ + "mysql": "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'", + "spark": "CREATE TABLE z (a INT) COMMENT 'x'", + }, + ) + + def test_identity(self): + self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") + + def test_introducers(self): + self.validate_all( + "_utf8mb4 'hola'", + read={ + "mysql": "_utf8mb4'hola'", + }, + write={ + "mysql": "_utf8mb4 'hola'", + }, + ) + + def test_binary_literal(self): + self.validate_all( + "SELECT 0xCC", + write={ + "mysql": "SELECT b'11001100'", + "spark": "SELECT X'11001100'", + }, + ) + self.validate_all( + "SELECT 0xz", + write={ + "mysql": "SELECT `0xz`", + }, + ) + self.validate_all( + "SELECT 0XCC", + write={ + "mysql": "SELECT 0 AS XCC", + }, + ) + + def test_string_literals(self): + self.validate_all( + 'SELECT "2021-01-01" + INTERVAL 1 MONTH', + write={ + "mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH", + }, + ) + + def test_convert(self): + self.validate_all( + "CONVERT(x USING latin1)", + write={ + "mysql": "CAST(x AS CHAR CHARACTER SET latin1)", + }, + ) + self.validate_all( + "CAST(x AS CHAR CHARACTER SET latin1)", + write={ + "mysql": "CAST(x AS CHAR CHARACTER SET latin1)", + }, + ) + + def test_hash_comments(self): + self.validate_all( + "SELECT 1 # arbitrary content,,, until end-of-line", + write={ + "mysql": "SELECT 1", + }, + ) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py new file mode 100644 index 0000000..15dbfd0 --- /dev/null +++ b/tests/dialects/test_postgres.py @@ -0,0 +1,93 @@ +from sqlglot import ParseError, transpile +from tests.dialects.test_dialect import Validator + + +class TestPostgres(Validator): + dialect = "postgres" + + def test_ddl(self): + self.validate_all( + "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)", + write={ + "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)" + }, + ) + self.validate_all( + "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)", + write={ + "postgres": "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)" + }, + ) + self.validate_all( + "CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))", + write={ + "postgres": "CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))" + }, + ) + self.validate_all( + "CREATE TABLE products (" + "product_no INT UNIQUE," + " name TEXT," + " price DECIMAL CHECK 
(price > 0)," + " discounted_price DECIMAL CONSTRAINT positive_discount CHECK (discounted_price > 0)," + " CHECK (product_no > 1)," + " CONSTRAINT valid_discount CHECK (price > discounted_price))", + write={ + "postgres": "CREATE TABLE products (" + "product_no INT UNIQUE," + " name TEXT," + " price DECIMAL CHECK (price > 0)," + " discounted_price DECIMAL CONSTRAINT positive_discount CHECK (discounted_price > 0)," + " CHECK (product_no > 1)," + " CONSTRAINT valid_discount CHECK (price > discounted_price))" + }, + ) + + with self.assertRaises(ParseError): + transpile( + "CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres" + ) + with self.assertRaises(ParseError): + transpile( + "CREATE TABLE products (price DECIMAL, CHECK price > 1)", + read="postgres", + ) + + def test_postgres(self): + self.validate_all( + "CREATE TABLE x (a INT SERIAL)", + read={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"}, + write={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"}, + ) + self.validate_all( + "CREATE TABLE x (a UUID, b BYTEA)", + write={ + "presto": "CREATE TABLE x (a UUID, b VARBINARY)", + "hive": "CREATE TABLE x (a UUID, b BINARY)", + "spark": "CREATE TABLE x (a UUID, b BINARY)", + }, + ) + self.validate_all( + "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS 1 PRECEDING)", + write={ + "postgres": "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)", + }, + ) + self.validate_all( + "SELECT * FROM x FETCH 1 ROW", + write={ + "postgres": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY", + "presto": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY", + "hive": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY", + "spark": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY", + }, + ) + self.validate_all( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + }, + ) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py new file mode 100644 index 0000000..eb9aa5c --- /dev/null +++ b/tests/dialects/test_presto.py @@ -0,0 +1,422 @@ +from sqlglot import UnsupportedError +from tests.dialects.test_dialect import Validator + + +class TestPresto(Validator): + dialect = "presto" + + def test_cast(self): + self.validate_all( + "CAST(a AS ARRAY(INT))", + write={ + "bigquery": "CAST(a AS ARRAY<INT64>)", + "duckdb": "CAST(a AS ARRAY<INT>)", + "presto": "CAST(a AS ARRAY(INTEGER))", + "spark": "CAST(a AS ARRAY<INT>)", + }, + ) + self.validate_all( + "CAST(a AS VARCHAR)", + write={ + "bigquery": "CAST(a AS STRING)", + "duckdb": "CAST(a AS TEXT)", + "presto": "CAST(a AS VARCHAR)", + "spark": "CAST(a AS STRING)", + }, + ) + self.validate_all( + "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))", + write={ + "bigquery": "CAST([1, 2] AS ARRAY<INT64>)", + "duckdb": "CAST(LIST_VALUE(1, 2) AS ARRAY<BIGINT>)", + "presto": "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))", + "spark": "CAST(ARRAY(1, 2) AS ARRAY<LONG>)", + }, + ) + self.validate_all( + "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INT,INT))", + write={ + "bigquery": "CAST(MAP([1], [1]) AS MAP<INT64, INT64>)", + "duckdb": "CAST(MAP(LIST_VALUE(1), LIST_VALUE(1)) AS MAP<INT, INT>)", + "presto": "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INTEGER, INTEGER))", + "hive": "CAST(MAP(1, 1) AS MAP<INT, INT>)", + "spark": "CAST(MAP_FROM_ARRAYS(ARRAY(1), ARRAY(1)) AS MAP<INT, INT>)", + }, + ) + self.validate_all( + "CAST(MAP(ARRAY['a','b','c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INT)))", + write={ + "bigquery": "CAST(MAP(['a', 'b', 'c'], [[1], [2], [3]]) AS MAP<STRING, ARRAY<INT64>>)", + "duckdb": "CAST(MAP(LIST_VALUE('a', 'b', 'c'), LIST_VALUE(LIST_VALUE(1), LIST_VALUE(2), LIST_VALUE(3))) AS MAP<TEXT, ARRAY<INT>>)", + "presto": "CAST(MAP(ARRAY['a', 'b', 'c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INTEGER)))", + "hive": "CAST(MAP('a', ARRAY(1), 'b', ARRAY(2), 'c', ARRAY(3)) AS MAP<STRING, ARRAY<INT>>)", + "spark": "CAST(MAP_FROM_ARRAYS(ARRAY('a', 'b', 'c'), ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))) AS MAP<STRING, ARRAY<INT>>)", + }, + ) + self.validate_all( + "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", + write={ + "bigquery": "CAST(x AS TIMESTAMPTZ(9))", + "duckdb": "CAST(x AS TIMESTAMPTZ(9))", + "presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", + "hive": "CAST(x AS TIMESTAMPTZ(9))", + "spark": "CAST(x AS TIMESTAMPTZ(9))", + }, + ) + + def test_regex(self): + self.validate_all( + "REGEXP_LIKE(a, 'x')", + write={ + "duckdb": "REGEXP_MATCHES(a, 'x')", + "presto": "REGEXP_LIKE(a, 'x')", + "hive": "a RLIKE 'x'", + "spark": "a RLIKE 'x'", + }, + ) + self.validate_all( + "SPLIT(x, 'a.')", + write={ + "duckdb": "STR_SPLIT(x, 'a.')", + "presto": "SPLIT(x, 'a.')", + "hive": "SPLIT(x, CONCAT('\\\\Q', 'a.'))", + "spark": "SPLIT(x, CONCAT('\\\\Q', 'a.'))", + }, + ) + self.validate_all( + "REGEXP_SPLIT(x, 'a.')", + write={ + "duckdb": "STR_SPLIT_REGEX(x, 'a.')", + "presto": "REGEXP_SPLIT(x, 'a.')", + "hive": "SPLIT(x, 'a.')", + "spark": "SPLIT(x, 'a.')", + }, + ) + self.validate_all( + "CARDINALITY(x)", + write={ + "duckdb": "ARRAY_LENGTH(x)", + "presto": "CARDINALITY(x)", + "hive": "SIZE(x)", + "spark": "SIZE(x)", + }, + ) + + def test_time(self): + self.validate_all( + "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", + write={ + "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", + "presto": "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", + "hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", + "spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", + }, + ) + self.validate_all( + "DATE_PARSE(x, '%Y-%m-%d %H:%i:%S')", + write={ + "duckdb": "STRPTIME(x, '%Y-%m-%d %H:%M:%S')", + "presto": "DATE_PARSE(x, '%Y-%m-%d %H:%i:%S')", + "hive": "CAST(x AS TIMESTAMP)", + "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss')", + }, + ) + self.validate_all( + "DATE_PARSE(x, '%Y-%m-%d')", + write={ + "duckdb": "STRPTIME(x, '%Y-%m-%d')", + "presto": "DATE_PARSE(x, '%Y-%m-%d')", + "hive": "CAST(x AS TIMESTAMP)", + "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd')", + }, + ) + self.validate_all( + "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')", + write={ + "duckdb": "STRPTIME(SUBSTR(x, 1, 10), '%Y-%m-%d')", + "presto": "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')", + "hive": "CAST(SUBSTR(x, 1, 10) AS TIMESTAMP)", + "spark": "TO_TIMESTAMP(SUBSTR(x, 1, 10), 'yyyy-MM-dd')", + }, + ) + self.validate_all( + "FROM_UNIXTIME(x)", + write={ + "duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))", + "presto": "FROM_UNIXTIME(x)", + "hive": "FROM_UNIXTIME(x)", + "spark": "FROM_UNIXTIME(x)", + }, + ) + self.validate_all( + "TO_UNIXTIME(x)", + write={ + "duckdb": "EPOCH(x)", + "presto": "TO_UNIXTIME(x)", + "hive": "UNIX_TIMESTAMP(x)", + "spark": "UNIX_TIMESTAMP(x)", + }, + ) + self.validate_all( + "DATE_ADD('day', 1, x)", + write={ + "duckdb": "x + INTERVAL 1 day", + "presto": "DATE_ADD('day', 1, x)", + "hive": "DATE_ADD(x, 1)", + "spark": "DATE_ADD(x, 1)", + }, + ) + + def test_ddl(self): + 
self.validate_all( + "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", + write={ + "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", + "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", + "spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", + }, + ) + self.validate_all( + "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1", + write={ + "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1", + "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1", + "spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1", + }, + ) + self.validate_all( + "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])", + write={ + "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])", + "hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", + "spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", + }, + ) + self.validate_all( + "CREATE TABLE x WITH (bucket_by = ARRAY['y'], bucket_count = 64) AS SELECT 1 AS y", + write={ + "presto": "CREATE TABLE x WITH (bucket_by = ARRAY['y'], bucket_count = 64) AS SELECT 1 AS y", + "hive": "CREATE TABLE x TBLPROPERTIES ('bucket_by' = ARRAY('y'), 'bucket_count' = 64) AS SELECT 1 AS y", + "spark": "CREATE TABLE x TBLPROPERTIES ('bucket_by' = ARRAY('y'), 'bucket_count' = 64) AS SELECT 1 AS y", + }, + ) + self.validate_all( + "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b VARCHAR))", + write={ + "presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b VARCHAR))", + "hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRING>)", + "spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRING>)", + }, + ) + self.validate_all( + "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b ROW(nested_col_a VARCHAR, nested_col_b VARCHAR)))", + write={ + "presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b ROW(nested_col_a VARCHAR, nested_col_b VARCHAR)))", + "hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRUCT<nested_col_a: STRING, nested_col_b: STRING>>)", + "spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRUCT<nested_col_a: STRING, nested_col_b: STRING>>)", + }, + ) + + self.validate( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", + read="presto", + write="presto", + ) + + def test_quotes(self): + self.validate_all( + "''''", + write={ + "duckdb": "''''", + "presto": "''''", + "hive": "'\\''", + "spark": "'\\''", + }, + ) + self.validate_all( + "'x'", + write={ + "duckdb": "'x'", + "presto": "'x'", + "hive": "'x'", + "spark": "'x'", + }, + ) + self.validate_all( + "'''x'''", + write={ + "duckdb": "'''x'''", + "presto": "'''x'''", + "hive": "'\\'x\\''", + "spark": "'\\'x\\''", + }, + ) + self.validate_all( + "'''x'", + write={ + "duckdb": "'''x'", + "presto": "'''x'", + "hive": "'\\'x'", + "spark": "'\\'x'", + }, + ) + self.validate_all( + "x IN ('a', 'a''b')", + write={ + "duckdb": "x IN ('a', 'a''b')", + "presto": "x IN ('a', 'a''b')", + "hive": "x IN ('a', 'a\\'b')", + "spark": "x IN ('a', 'a\\'b')", + }, + ) + + def test_unnest(self): + self.validate_all( + "SELECT a FROM x CROSS JOIN UNNEST(ARRAY(y)) AS t (a)", + write={ + "presto": "SELECT a FROM x CROSS JOIN UNNEST(ARRAY[y]) AS t(a)", + "hive": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a", + "spark": 
"SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a", + }, + ) + + def test_presto(self): + self.validate_all( + 'SELECT a."b" FROM "foo"', + write={ + "duckdb": 'SELECT a."b" FROM "foo"', + "presto": 'SELECT a."b" FROM "foo"', + "spark": "SELECT a.`b` FROM `foo`", + }, + ) + self.validate_all( + "SELECT ARRAY[1, 2]", + write={ + "bigquery": "SELECT [1, 2]", + "duckdb": "SELECT LIST_VALUE(1, 2)", + "presto": "SELECT ARRAY[1, 2]", + "spark": "SELECT ARRAY(1, 2)", + }, + ) + self.validate_all( + "SELECT APPROX_DISTINCT(a) FROM foo", + write={ + "duckdb": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + "presto": "SELECT APPROX_DISTINCT(a) FROM foo", + "hive": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + "spark": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + }, + ) + self.validate_all( + "SELECT APPROX_DISTINCT(a, 0.1) FROM foo", + write={ + "duckdb": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + "presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo", + "hive": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + "spark": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + }, + ) + self.validate_all( + "SELECT APPROX_DISTINCT(a, 0.1) FROM foo", + write={ + "presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo", + "hive": UnsupportedError, + "spark": UnsupportedError, + }, + ) + self.validate_all( + "SELECT JSON_EXTRACT(x, '$.name')", + write={ + "presto": "SELECT JSON_EXTRACT(x, '$.name')", + "hive": "SELECT GET_JSON_OBJECT(x, '$.name')", + "spark": "SELECT GET_JSON_OBJECT(x, '$.name')", + }, + ) + self.validate_all( + "SELECT JSON_EXTRACT_SCALAR(x, '$.name')", + write={ + "presto": "SELECT JSON_EXTRACT_SCALAR(x, '$.name')", + "hive": "SELECT GET_JSON_OBJECT(x, '$.name')", + "spark": "SELECT GET_JSON_OBJECT(x, '$.name')", + }, + ) + self.validate_all( + "'\u6bdb'", + write={ + "presto": "'\u6bdb'", + "hive": "'\u6bdb'", + "spark": "'\u6bdb'", + }, + ) + self.validate_all( + "SELECT ARRAY_SORT(x, (left, right) -> -1)", + write={ + "duckdb": "SELECT ARRAY_SORT(x)", + "presto": "SELECT ARRAY_SORT(x, (left, right) -> -1)", + "hive": "SELECT SORT_ARRAY(x)", + "spark": "SELECT ARRAY_SORT(x, (left, right) -> -1)", + }, + ) + self.validate_all( + "SELECT ARRAY_SORT(x)", + write={ + "presto": "SELECT ARRAY_SORT(x)", + "hive": "SELECT SORT_ARRAY(x)", + "spark": "SELECT ARRAY_SORT(x)", + }, + ) + self.validate_all( + "SELECT ARRAY_SORT(x, (left, right) -> -1)", + write={ + "hive": UnsupportedError, + }, + ) + self.validate_all( + "MAP(a, b)", + write={ + "hive": UnsupportedError, + "spark": "MAP_FROM_ARRAYS(a, b)", + }, + ) + self.validate_all( + "MAP(ARRAY(a, b), ARRAY(c, d))", + write={ + "hive": "MAP(a, c, b, d)", + "presto": "MAP(ARRAY[a, b], ARRAY[c, d])", + "spark": "MAP_FROM_ARRAYS(ARRAY(a, b), ARRAY(c, d))", + }, + ) + self.validate_all( + "MAP(ARRAY('a'), ARRAY('b'))", + write={ + "hive": "MAP('a', 'b')", + "presto": "MAP(ARRAY['a'], ARRAY['b'])", + "spark": "MAP_FROM_ARRAYS(ARRAY('a'), ARRAY('b'))", + }, + ) + self.validate_all( + "SELECT * FROM UNNEST(ARRAY['7', '14']) AS x", + write={ + "bigquery": "SELECT * FROM UNNEST(['7', '14'])", + "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS x", + "hive": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x", + "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x", + }, + ) + self.validate_all( + "SELECT * FROM UNNEST(ARRAY['7', '14']) AS x(y)", + write={ + "bigquery": "SELECT * FROM UNNEST(['7', '14']) AS y", + "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS x(y)", + "hive": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x(y)", + "spark": "SELECT * FROM 
UNNEST(ARRAY('7', '14')) AS x(y)", + }, + ) + self.validate_all( + "WITH RECURSIVE t(n) AS (VALUES (1) UNION ALL SELECT n+1 FROM t WHERE n < 100 ) SELECT sum(n) FROM t", + write={ + "presto": "WITH RECURSIVE t(n) AS (VALUES (1) UNION ALL SELECT n + 1 FROM t WHERE n < 100) SELECT SUM(n) FROM t", + "spark": UnsupportedError, + }, + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py new file mode 100644 index 0000000..62f78e1 --- /dev/null +++ b/tests/dialects/test_snowflake.py @@ -0,0 +1,145 @@ +from sqlglot import UnsupportedError +from tests.dialects.test_dialect import Validator + + +class TestSnowflake(Validator): + dialect = "snowflake" + + def test_snowflake(self): + self.validate_all( + 'x:a:"b c"', + write={ + "duckdb": "x['a']['b c']", + "hive": "x['a']['b c']", + "presto": "x['a']['b c']", + "snowflake": "x['a']['b c']", + "spark": "x['a']['b c']", + }, + ) + self.validate_all( + "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10", + write={ + "bigquery": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a NULLS LAST LIMIT 10", + "snowflake": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10", + }, + ) + self.validate_all( + "SELECT a FROM test AS t QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY Z) = 1", + write={ + "bigquery": "SELECT a FROM test AS t QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY Z NULLS LAST) = 1", + "snowflake": "SELECT a FROM test AS t QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY Z) = 1", + }, + ) + self.validate_all( + "SELECT TO_TIMESTAMP(1659981729)", + write={ + "bigquery": "SELECT UNIX_TO_TIME(1659981729)", + "snowflake": "SELECT TO_TIMESTAMP(1659981729)", + "spark": "SELECT FROM_UNIXTIME(1659981729)", + }, + ) + self.validate_all( + "SELECT TO_TIMESTAMP(1659981729000, 3)", + write={ + "bigquery": "SELECT UNIX_TO_TIME(1659981729000, 'millis')", + "snowflake": "SELECT TO_TIMESTAMP(1659981729000, 3)", + "spark": "SELECT TIMESTAMP_MILLIS(1659981729000)", + }, + ) + self.validate_all( + "SELECT TO_TIMESTAMP('1659981729')", + write={ + "bigquery": "SELECT UNIX_TO_TIME('1659981729')", + "snowflake": "SELECT TO_TIMESTAMP('1659981729')", + "spark": "SELECT FROM_UNIXTIME('1659981729')", + }, + ) + self.validate_all( + "SELECT TO_TIMESTAMP(1659981729000000000, 9)", + write={ + "bigquery": "SELECT UNIX_TO_TIME(1659981729000000000, 'micros')", + "snowflake": "SELECT TO_TIMESTAMP(1659981729000000000, 9)", + "spark": "SELECT TIMESTAMP_MICROS(1659981729000000000)", + }, + ) + self.validate_all( + "SELECT TO_TIMESTAMP('2013-04-05 01:02:03')", + write={ + "bigquery": "SELECT STR_TO_TIME('2013-04-05 01:02:03', '%Y-%m-%d %H:%M:%S')", + "snowflake": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-mm-dd hh24:mi:ss')", + "spark": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-MM-dd HH:mm:ss')", + }, + ) + self.validate_all( + "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", + read={ + "bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')", + "duckdb": "SELECT STRPTIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')", + "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", + }, + write={ + "bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')", + "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", + "spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')", + }, + ) + self.validate_all( + 
"SELECT IFF(TRUE, 'true', 'false')", + write={ + "snowflake": "SELECT IFF(TRUE, 'true', 'false')", + }, + ) + self.validate_all( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname", + }, + ) + self.validate_all( + "SELECT ARRAY_AGG(DISTINCT a)", + write={ + "spark": "SELECT COLLECT_LIST(DISTINCT a)", + "snowflake": "SELECT ARRAY_AGG(DISTINCT a)", + }, + ) + self.validate_all( + "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + write={ + "snowflake": UnsupportedError, + }, + ) + self.validate_all( + "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + write={ + "snowflake": UnsupportedError, + }, + ) + self.validate_all( + "SELECT ARRAY_UNION_AGG(a)", + write={ + "snowflake": "SELECT ARRAY_UNION_AGG(a)", + }, + ) + self.validate_all( + "SELECT NVL2(a, b, c)", + write={ + "snowflake": "SELECT NVL2(a, b, c)", + }, + ) + self.validate_all( + "SELECT $$a$$", + write={ + "snowflake": "SELECT 'a'", + }, + ) + self.validate_all( + r"SELECT $$a ' \ \t \x21 z $ $$", + write={ + "snowflake": r"SELECT 'a \' \\ \\t \\x21 z $ '", + }, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py new file mode 100644 index 0000000..8794fed --- /dev/null +++ b/tests/dialects/test_spark.py @@ -0,0 +1,226 @@ +from tests.dialects.test_dialect import Validator + + +class TestSpark(Validator): + dialect = "spark" + + def test_ddl(self): + self.validate_all( + "CREATE TABLE db.example_table (col_a struct)", + write={ + "presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b VARCHAR))", + "hive": "CREATE TABLE db.example_table (col_a STRUCT)", + "spark": "CREATE TABLE db.example_table (col_a STRUCT)", + }, + ) + self.validate_all( + "CREATE TABLE db.example_table (col_a struct>)", + write={ + "bigquery": "CREATE TABLE db.example_table (col_a STRUCT>)", + "presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b ROW(nested_col_a VARCHAR, nested_col_b VARCHAR)))", + "hive": "CREATE TABLE db.example_table (col_a STRUCT>)", + "spark": "CREATE TABLE db.example_table (col_a STRUCT>)", + }, + ) + self.validate_all( + "CREATE TABLE db.example_table (col_a array, col_b array>)", + write={ + "bigquery": "CREATE TABLE db.example_table (col_a ARRAY, col_b ARRAY>)", + "presto": "CREATE TABLE db.example_table (col_a ARRAY(INTEGER), col_b ARRAY(ARRAY(INTEGER)))", + "hive": "CREATE TABLE db.example_table (col_a ARRAY, col_b ARRAY>)", + "spark": "CREATE TABLE db.example_table (col_a ARRAY, col_b ARRAY>)", + }, + ) + self.validate_all( + "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", + write={ + "presto": "CREATE TABLE x WITH (TABLE_FORMAT = 'ICEBERG', PARTITIONED_BY = ARRAY['MONTHS'])", + "hive": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", + "spark": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", + }, + 
) + self.validate_all( + "CREATE TABLE test STORED AS PARQUET AS SELECT 1", + write={ + "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", + "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", + "spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", + }, + ) + self.validate_all( + "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1", + write={ + "presto": "CREATE TABLE test WITH (TABLE_FORMAT = 'ICEBERG', FORMAT = 'PARQUET') AS SELECT 1", + "hive": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1", + "spark": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1", + }, + ) + self.validate_all( + """CREATE TABLE blah (col_a INT) COMMENT "Test comment: blah" PARTITIONED BY (date STRING) STORED AS ICEBERG TBLPROPERTIES('x' = '1')""", + write={ + "presto": """CREATE TABLE blah ( + col_a INTEGER, + date VARCHAR +) +COMMENT='Test comment: blah' +WITH ( + PARTITIONED_BY = ARRAY['date'], + FORMAT = 'ICEBERG', + x = '1' +)""", + "hive": """CREATE TABLE blah ( + col_a INT +) +COMMENT 'Test comment: blah' +PARTITIONED BY ( + date STRING +) +STORED AS ICEBERG +TBLPROPERTIES ( + 'x' = '1' +)""", + "spark": """CREATE TABLE blah ( + col_a INT +) +COMMENT 'Test comment: blah' +PARTITIONED BY ( + date STRING +) +STORED AS ICEBERG +TBLPROPERTIES ( + 'x' = '1' +)""", + }, + pretty=True, + ) + + def test_to_date(self): + self.validate_all( + "TO_DATE(x, 'yyyy-MM-dd')", + write={ + "duckdb": "CAST(x AS DATE)", + "hive": "TO_DATE(x)", + "presto": "CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE)", + "spark": "TO_DATE(x)", + }, + ) + self.validate_all( + "TO_DATE(x, 'yyyy')", + write={ + "duckdb": "CAST(STRPTIME(x, '%Y') AS DATE)", + "hive": "TO_DATE(x, 'yyyy')", + "presto": "CAST(DATE_PARSE(x, '%Y') AS DATE)", + "spark": "TO_DATE(x, 'yyyy')", + }, + ) + + def test_hint(self): + self.validate_all( + "SELECT /*+ COALESCE(3) */ * FROM x", + write={ + "spark": "SELECT /*+ COALESCE(3) */ * FROM x", + }, + ) + self.validate_all( + "SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x", + write={ + "spark": "SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x", + }, + ) + + def test_spark(self): + self.validate_all( + "ARRAY_SORT(x, (left, right) -> -1)", + write={ + "duckdb": "ARRAY_SORT(x)", + "presto": "ARRAY_SORT(x, (left, right) -> -1)", + "hive": "SORT_ARRAY(x)", + "spark": "ARRAY_SORT(x, (left, right) -> -1)", + }, + ) + self.validate_all( + "ARRAY(0, 1, 2)", + write={ + "bigquery": "[0, 1, 2]", + "duckdb": "LIST_VALUE(0, 1, 2)", + "presto": "ARRAY[0, 1, 2]", + "hive": "ARRAY(0, 1, 2)", + "spark": "ARRAY(0, 1, 2)", + }, + ) + + self.validate_all( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname NULLS FIRST", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname NULLS FIRST", + }, + ) + self.validate_all( + "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + 
write={ + "duckdb": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + "presto": "SELECT APPROX_DISTINCT(a) FROM foo", + "hive": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + "spark": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + }, + ) + self.validate_all( + "MONTH('2021-03-01')", + write={ + "duckdb": "MONTH(CAST('2021-03-01' AS DATE))", + "presto": "MONTH(CAST(SUBSTR(CAST('2021-03-01' AS VARCHAR), 1, 10) AS DATE))", + "hive": "MONTH(TO_DATE('2021-03-01'))", + "spark": "MONTH(TO_DATE('2021-03-01'))", + }, + ) + self.validate_all( + "YEAR('2021-03-01')", + write={ + "duckdb": "YEAR(CAST('2021-03-01' AS DATE))", + "presto": "YEAR(CAST(SUBSTR(CAST('2021-03-01' AS VARCHAR), 1, 10) AS DATE))", + "hive": "YEAR(TO_DATE('2021-03-01'))", + "spark": "YEAR(TO_DATE('2021-03-01'))", + }, + ) + self.validate_all( + "'\u6bdb'", + write={ + "duckdb": "'毛'", + "presto": "'毛'", + "hive": "'毛'", + "spark": "'毛'", + }, + ) + self.validate_all( + "SELECT LEFT(x, 2), RIGHT(x, 2)", + write={ + "duckdb": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)", + "presto": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)", + "hive": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)", + "spark": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)", + }, + ) + self.validate_all( + "MAP_FROM_ARRAYS(ARRAY(1), c)", + write={ + "duckdb": "MAP(LIST_VALUE(1), c)", + "presto": "MAP(ARRAY[1], c)", + "hive": "MAP(ARRAY(1), c)", + "spark": "MAP_FROM_ARRAYS(ARRAY(1), c)", + }, + ) + self.validate_all( + "SELECT ARRAY_SORT(x)", + write={ + "duckdb": "SELECT ARRAY_SORT(x)", + "presto": "SELECT ARRAY_SORT(x)", + "hive": "SELECT SORT_ARRAY(x)", + "spark": "SELECT ARRAY_SORT(x)", + }, + ) diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py new file mode 100644 index 0000000..a0576de --- /dev/null +++ b/tests/dialects/test_sqlite.py @@ -0,0 +1,72 @@ +from tests.dialects.test_dialect import Validator + + +class TestSQLite(Validator): + dialect = "sqlite" + + def test_ddl(self): + self.validate_all( + """ + CREATE TABLE "Track" + ( + CONSTRAINT "PK_Track" FOREIGN KEY ("TrackId"), + FOREIGN KEY ("AlbumId") REFERENCES "Album" ("AlbumId") + ON DELETE NO ACTION ON UPDATE NO ACTION, + FOREIGN KEY ("AlbumId") ON DELETE CASCADE ON UPDATE RESTRICT, + FOREIGN KEY ("AlbumId") ON DELETE SET NULL ON UPDATE SET DEFAULT + ) + """, + write={ + "sqlite": """CREATE TABLE "Track" ( + CONSTRAINT "PK_Track" FOREIGN KEY ("TrackId"), + FOREIGN KEY ("AlbumId") REFERENCES "Album"("AlbumId") ON DELETE NO ACTION ON UPDATE NO ACTION, + FOREIGN KEY ("AlbumId") ON DELETE CASCADE ON UPDATE RESTRICT, + FOREIGN KEY ("AlbumId") ON DELETE SET NULL ON UPDATE SET DEFAULT +)""", + }, + pretty=True, + ) + self.validate_all( + "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)", + read={ + "mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)", + }, + write={ + "sqlite": "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)", + "mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)", + }, + ) + self.validate_all( + """CREATE TABLE "x" ("Name" NVARCHAR(200) NOT NULL)""", + write={ + "sqlite": """CREATE TABLE "x" ("Name" TEXT(200) NOT NULL)""", + "mysql": "CREATE TABLE `x` (`Name` VARCHAR(200) NOT NULL)", + }, + ) + + def test_sqlite(self): + self.validate_all( + "SELECT CAST([a].[b] AS SMALLINT) FROM foo", + write={ + "sqlite": 'SELECT CAST("a"."b" AS INTEGER) FROM foo', + "spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo", + }, + ) + self.validate_all( + 
"EDITDIST3(col1, col2)", + read={ + "sqlite": "EDITDIST3(col1, col2)", + "spark": "LEVENSHTEIN(col1, col2)", + }, + write={ + "sqlite": "EDITDIST3(col1, col2)", + "spark": "LEVENSHTEIN(col1, col2)", + }, + ) + self.validate_all( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "sqlite": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + }, + ) diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py new file mode 100644 index 0000000..1fe1a57 --- /dev/null +++ b/tests/dialects/test_starrocks.py @@ -0,0 +1,8 @@ +from tests.dialects.test_dialect import Validator + + +class TestMySQL(Validator): + dialect = "starrocks" + + def test_identity(self): + self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") diff --git a/tests/dialects/test_tableau.py b/tests/dialects/test_tableau.py new file mode 100644 index 0000000..0f612dd --- /dev/null +++ b/tests/dialects/test_tableau.py @@ -0,0 +1,62 @@ +from tests.dialects.test_dialect import Validator + + +class TestTableau(Validator): + dialect = "tableau" + + def test_tableau(self): + self.validate_all( + "IF x = 'a' THEN y ELSE NULL END", + read={ + "presto": "IF(x = 'a', y, NULL)", + }, + write={ + "presto": "IF(x = 'a', y, NULL)", + "hive": "IF(x = 'a', y, NULL)", + "tableau": "IF x = 'a' THEN y ELSE NULL END", + }, + ) + self.validate_all( + "IFNULL(a, 0)", + read={ + "presto": "COALESCE(a, 0)", + }, + write={ + "presto": "COALESCE(a, 0)", + "hive": "COALESCE(a, 0)", + "tableau": "IFNULL(a, 0)", + }, + ) + self.validate_all( + "COUNTD(a)", + read={ + "presto": "COUNT(DISTINCT a)", + }, + write={ + "presto": "COUNT(DISTINCT a)", + "hive": "COUNT(DISTINCT a)", + "tableau": "COUNTD(a)", + }, + ) + self.validate_all( + "COUNTD((a))", + read={ + "presto": "COUNT(DISTINCT(a))", + }, + write={ + "presto": "COUNT(DISTINCT (a))", + "hive": "COUNT(DISTINCT (a))", + "tableau": "COUNTD((a))", + }, + ) + self.validate_all( + "COUNT(a)", + read={ + "presto": "COUNT(a)", + }, + write={ + "presto": "COUNT(a)", + "hive": "COUNT(a)", + "tableau": "COUNT(a)", + }, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql new file mode 100644 index 0000000..40f11a2 --- /dev/null +++ b/tests/fixtures/identity.sql @@ -0,0 +1,514 @@ +SUM(1) +SUM(CASE WHEN x > 1 THEN 1 ELSE 0 END) / y +1 +1.0 +1E2 +1E+2 +1E-2 +1.1E10 +1.12e-10 +-11.023E7 * 3 +(1 * 2) / (3 - 5) +((TRUE)) +'' +'''' +'x' +'\x' +"x" +"" +x +x % 1 +x < 1 +x <= 1 +x > 1 +x >= 1 +x <> 1 +x = y OR x > 1 +x & 1 +x | 1 +x ^ 1 +~x +x << 1 +x >> 1 +x >> 1 | 1 & 1 ^ 1 +x || y +1 - -1 +dec.x + y +a.filter +a.b.c +a.b.c.d +a.b.c.d.e +a.b.c.d.e[0] +a.b.c.d.e[0].f +a[0][0].b.c[1].d.e.f[1][1] +a[0].b[1] +a[0].b.c['d'] +a.b.C() +a['x'].b.C() +a.B() +a['x'].C() +int.x +map.x +x IN (-1, 1) +x IN ('a', 'a''a') +x IN ((1)) +x BETWEEN -1 AND 1 +x BETWEEN 'a' || b AND 'c' || d +NOT x IS NULL +x IS TRUE +x IS FALSE +time +zone +ARRAY +CURRENT_DATE +CURRENT_DATE('UTC') +CURRENT_DATE AT TIME ZONE 'UTC' +CURRENT_DATE AT TIME ZONE zone_column +CURRENT_DATE AT TIME ZONE 'UTC' AT TIME ZONE 'Asia/Tokio' +ARRAY() +ARRAY(1, 2) +ARRAY_CONTAINS(x, 1) +EXTRACT(x FROM y) +EXTRACT(DATE FROM y) +CONCAT_WS('-', 'a', 'b') +CONCAT_WS('-', 'a', 'b', 'c') +POSEXPLODE("x") AS ("a", "b") +POSEXPLODE("x") AS ("a", "b", "c") +STR_POSITION(x, 'a') +STR_POSITION(x, 'a', 3) +SPLIT(SPLIT(referrer, 
'utm_source=')[OFFSET(1)], "&")[OFFSET(0)] +x[ORDINAL(1)][SAFE_OFFSET(2)] +x LIKE SUBSTR('abc', 1, 1) +x LIKE y +x LIKE a.y +x LIKE '%y%' +x ILIKE '%y%' +x LIKE '%y%' ESCAPE '\' +x ILIKE '%y%' ESCAPE '\' +1 AS escape +INTERVAL '1' day +INTERVAL '1' month +INTERVAL '1 day' +INTERVAL 2 months +INTERVAL 1 + 3 days +TIMESTAMP_DIFF(CURRENT_TIMESTAMP(), 1, DAY) +DATETIME_DIFF(CURRENT_DATE, 1, DAY) +QUANTILE(x, 0.5) +REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2])) +REGEXP_LIKE('new york', '.') +REGEXP_SPLIT('new york', '.') +SPLIT('new york', '.') +X((y AS z)).1 +(x AS y, y AS z) +REPLACE(1) +DATE(x) = DATE(y) +TIMESTAMP(DATE(x)) +TIMESTAMP_TRUNC(COALESCE(time_field, CURRENT_TIMESTAMP()), DAY) +COUNT(DISTINCT CASE WHEN DATE_TRUNC(DATE(time_field), isoweek) = DATE_TRUNC(DATE(time_field2), isoweek) THEN report_id ELSE NULL END) +x[y - 1] +CASE WHEN SUM(x) > 3 THEN 1 END OVER (PARTITION BY x) +SUM(ROW() OVER (PARTITION BY x)) +SUM(ROW() OVER (PARTITION BY x + 1)) +SUM(ROW() OVER (PARTITION BY x AND y)) +(ROW() OVER ()) +CASE WHEN (x > 1) THEN 1 ELSE 0 END +CASE (1) WHEN 1 THEN 1 ELSE 0 END +CASE 1 WHEN 1 THEN 1 ELSE 0 END +x AT TIME ZONE 'UTC' +CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo' +SET x = 1 +SET -v +ADD JAR s3://bucket +ADD JARS s3://bucket, c +ADD FILE s3://file +ADD FILES s3://file, s3://a +ADD ARCHIVE s3://file +ADD ARCHIVES s3://file, s3://a +BEGIN IMMEDIATE TRANSACTION +COMMIT +USE db +NOT 1 +NOT NOT 1 +SELECT * FROM test +SELECT *, 1 FROM test +SELECT * FROM a.b +SELECT * FROM a.b.c +SELECT * FROM table +SELECT 1 +SELECT 1 FROM test +SELECT * FROM a, b, (SELECT 1) AS c +SELECT a FROM test +SELECT 1 AS filter +SELECT SUM(x) AS filter +SELECT 1 AS range FROM test +SELECT 1 AS count FROM test +SELECT 1 AS comment FROM test +SELECT 1 AS numeric FROM test +SELECT 1 AS number FROM test +SELECT t.count +SELECT DISTINCT x FROM test +SELECT DISTINCT x, y FROM test +SELECT DISTINCT TIMESTAMP_TRUNC(time_field, MONTH) AS time_value FROM "table" +SELECT DISTINCT ON (x) x, y FROM z +SELECT DISTINCT ON (x, y + 1) * FROM z +SELECT DISTINCT ON (x.y) * FROM z +SELECT top.x +SELECT TIMESTAMP(DATE_TRUNC(DATE(time_field), MONTH)) AS time_value FROM "table" +SELECT GREATEST((3 + 1), LEAST(3, 4)) +SELECT TRANSFORM(a, b -> b) AS x +SELECT AGGREGATE(a, (a, b) -> a + b) AS x +SELECT SUM(DISTINCT x) +SELECT SUM(x IGNORE NULLS) AS x +SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x +SELECT ARRAY_AGG(STRUCT(x, x AS y) ORDER BY z DESC) AS x +SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x +SELECT LAG(x) OVER (ORDER BY y) AS x +SELECT LEAD(a) OVER (ORDER BY b) AS a +SELECT LEAD(a, 1) OVER (PARTITION BY a ORDER BY a) AS x +SELECT LEAD(a, 1, b) OVER (PARTITION BY a ORDER BY a) AS x +SELECT X((a, b) -> a + b, z -> z) AS x +SELECT X(a -> "a" + ("z" - 1)) +SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0) +SELECT test.* FROM test +SELECT a AS b FROM test +SELECT "a"."b" FROM "a" +SELECT "a".b FROM a +SELECT a.b FROM "a" +SELECT a.b FROM a +SELECT '"hi' AS x FROM x +SELECT 1 AS "|sum" FROM x +SELECT '\"hi' AS x FROM x +SELECT 1 AS b FROM test +SELECT 1 AS "b" FROM test +SELECT 1 + 1 FROM test +SELECT 1 - 1 FROM test +SELECT 1 * 1 FROM test +SELECT 1 % 1 FROM test +SELECT 1 / 1 FROM test +SELECT 1 < 2 FROM test +SELECT 1 <= 2 FROM test +SELECT 1 > 2 FROM test +SELECT 1 >= 2 FROM test +SELECT 1 <> 2 FROM test +SELECT JSON_EXTRACT(x, '$.name') +SELECT JSON_EXTRACT_SCALAR(x, '$.name') +SELECT x LIKE '%x%' FROM test +SELECT * FROM test LIMIT 100 +SELECT * 
FROM test LIMIT 100 OFFSET 200 +SELECT * FROM test FETCH FIRST 1 ROWS ONLY +SELECT * FROM test FETCH NEXT 1 ROWS ONLY +SELECT (1 > 2) AS x FROM test +SELECT NOT (1 > 2) FROM test +SELECT 1 + 2 AS x FROM test +SELECT a, b, 1 < 1 FROM test +SELECT a FROM test WHERE NOT FALSE +SELECT a FROM test WHERE a = 1 +SELECT a FROM test WHERE a = 1 AND b = 2 +SELECT a FROM test WHERE a IN (SELECT b FROM z) +SELECT a FROM test WHERE a IN ((SELECT 1), 2) +SELECT * FROM x WHERE y IN ((SELECT 1) EXCEPT (SELECT 2)) +SELECT * FROM x WHERE y IN (SELECT 1 UNION SELECT 2) +SELECT * FROM x WHERE y IN ((SELECT 1 UNION SELECT 2)) +SELECT * FROM x WHERE y IN (WITH z AS (SELECT 1) SELECT * FROM z) +SELECT a FROM test WHERE (a > 1) +SELECT a FROM test WHERE a > (SELECT 1 FROM x GROUP BY y) +SELECT a FROM test WHERE EXISTS(SELECT 1) +SELECT a FROM test WHERE EXISTS(SELECT * FROM x UNION SELECT * FROM Y) OR TRUE +SELECT a FROM test WHERE TRUE OR NOT EXISTS(SELECT * FROM x) +SELECT a AS any, b AS some, c AS all, d AS exists FROM test WHERE a = ANY (SELECT 1) +SELECT a FROM test WHERE a > ALL (SELECT 1) +SELECT a FROM test WHERE (a, b) IN (SELECT 1, 2) +SELECT a FROM test ORDER BY a +SELECT a FROM test ORDER BY a, b +SELECT x FROM tests ORDER BY a DESC, b DESC, c +SELECT a FROM test ORDER BY a > 1 +SELECT * FROM test ORDER BY DATE DESC, TIMESTAMP DESC +SELECT * FROM test DISTRIBUTE BY y SORT BY x DESC ORDER BY l +SELECT * FROM test CLUSTER BY y +SELECT * FROM test CLUSTER BY y +SELECT * FROM test WHERE RAND() <= 0.1 DISTRIBUTE BY RAND() SORT BY RAND() +SELECT a, b FROM test GROUP BY 1 +SELECT a, b FROM test GROUP BY a +SELECT a, b FROM test WHERE a = 1 GROUP BY a HAVING a = 2 +SELECT a, b FROM test WHERE a = 1 GROUP BY a HAVING a = 2 ORDER BY a +SELECT a, b FROM test WHERE a = 1 GROUP BY CASE 1 WHEN 1 THEN 1 END +SELECT a FROM test GROUP BY GROUPING SETS (()) +SELECT a FROM test GROUP BY GROUPING SETS (x, ()) +SELECT a FROM test GROUP BY GROUPING SETS (x, (x, y), (x, y, z), q) +SELECT a FROM test GROUP BY CUBE (x) +SELECT a FROM test GROUP BY ROLLUP (x) +SELECT a FROM test GROUP BY CUBE (x) ROLLUP (x, y, z) +SELECT CASE WHEN a < b THEN 1 WHEN a < c THEN 2 ELSE 3 END FROM test +SELECT CASE 1 WHEN 1 THEN 1 ELSE 2 END +SELECT CASE 1 WHEN 1 THEN MAP('a', 'b') ELSE MAP('b', 'c') END['a'] +SELECT CASE 1 + 2 WHEN 1 THEN 1 ELSE 2 END +SELECT CASE TEST(1) + x[0] WHEN 1 THEN 1 ELSE 2 END +SELECT CASE x[0] WHEN 1 THEN 1 ELSE 2 END +SELECT CASE a.b WHEN 1 THEN 1 ELSE 2 END +SELECT CASE CASE x > 1 WHEN TRUE THEN 1 END WHEN 1 THEN 1 ELSE 2 END +SELECT a FROM (SELECT a FROM test) AS x +SELECT a FROM (SELECT a FROM (SELECT a FROM test) AS y) AS x +SELECT a FROM test WHERE a IN (1, 2, 3) OR b BETWEEN 1 AND 4 +SELECT a FROM test AS x TABLESAMPLE(BUCKET 1 OUT OF 5) +SELECT a FROM test TABLESAMPLE(BUCKET 1 OUT OF 5) +SELECT a FROM test TABLESAMPLE(BUCKET 1 OUT OF 5 ON x) +SELECT a FROM test TABLESAMPLE(BUCKET 1 OUT OF 5 ON RAND()) +SELECT a FROM test TABLESAMPLE(0.1 PERCENT) +SELECT a FROM test TABLESAMPLE(100) +SELECT a FROM test TABLESAMPLE(100 ROWS) +SELECT a FROM test TABLESAMPLE BERNOULLI (50) +SELECT a FROM test TABLESAMPLE SYSTEM (75) +SELECT ABS(a) FROM test +SELECT AVG(a) FROM test +SELECT CEIL(a) FROM test +SELECT COUNT(a) FROM test +SELECT COUNT(1) FROM test +SELECT COUNT(*) FROM test +SELECT COUNT(DISTINCT a) FROM test +SELECT EXP(a) FROM test +SELECT FLOOR(a) FROM test +SELECT FIRST(a) FROM test +SELECT GREATEST(a, b, c) FROM test +SELECT LAST(a) FROM test +SELECT LN(a) FROM test +SELECT LOG10(a) FROM test +SELECT MAX(a) FROM 
test +SELECT MIN(a) FROM test +SELECT POWER(a, 2) FROM test +SELECT QUANTILE(a, 0.95) FROM test +SELECT ROUND(a) FROM test +SELECT ROUND(a, 2) FROM test +SELECT SUM(a) FROM test +SELECT SQRT(a) FROM test +SELECT STDDEV(a) FROM test +SELECT STDDEV_POP(a) FROM test +SELECT STDDEV_SAMP(a) FROM test +SELECT VARIANCE(a) FROM test +SELECT VARIANCE_POP(a) FROM test +SELECT CAST(a AS INT) FROM test +SELECT CAST(a AS DATETIME) FROM test +SELECT CAST(a AS VARCHAR) FROM test +SELECT CAST(a < 1 AS INT) FROM test +SELECT CAST(a IS NULL AS INT) FROM test +SELECT COUNT(CAST(1 < 2 AS INT)) FROM test +SELECT COUNT(CASE WHEN CAST(1 < 2 AS BOOLEAN) THEN 1 END) FROM test +SELECT CAST(a AS DECIMAL) FROM test +SELECT CAST(a AS DECIMAL(1)) FROM test +SELECT CAST(a AS DECIMAL(1, 2)) FROM test +SELECT CAST(a AS MAP) FROM test +SELECT CAST(a AS TIMESTAMP) FROM test +SELECT CAST(a AS DATE) FROM test +SELECT CAST(a AS ARRAY) FROM test +SELECT TRY_CAST(a AS INT) FROM test +SELECT COALESCE(a, b, c) FROM test +SELECT IFNULL(a, b) FROM test +SELECT ANY_VALUE(a) FROM test +SELECT 1 FROM a JOIN b ON a.x = b.x +SELECT 1 FROM a JOIN b AS c ON a.x = b.x +SELECT 1 FROM a INNER JOIN b ON a.x = b.x +SELECT 1 FROM a LEFT JOIN b ON a.x = b.x +SELECT 1 FROM a RIGHT JOIN b ON a.x = b.x +SELECT 1 FROM a CROSS JOIN b ON a.x = b.x +SELECT 1 FROM a JOIN b USING (x) +SELECT 1 FROM a JOIN b USING (x, y, z) +SELECT 1 FROM a JOIN (SELECT a FROM c) AS b ON a.x = b.x AND a.x < 2 +SELECT 1 FROM a UNION SELECT 2 FROM b +SELECT 1 FROM a UNION ALL SELECT 2 FROM b +SELECT 1 FROM a JOIN b ON a.foo = b.bar JOIN c ON a.foo = c.bar +SELECT 1 FROM a LEFT JOIN b ON a.foo = b.bar JOIN c ON a.foo = c.bar +SELECT 1 FROM a LEFT INNER JOIN b ON a.foo = b.bar +SELECT 1 FROM a LEFT OUTER JOIN b ON a.foo = b.bar +SELECT 1 FROM a OUTER JOIN b ON a.foo = b.bar +SELECT 1 FROM a FULL JOIN b ON a.foo = b.bar +SELECT 1 UNION ALL SELECT 2 +SELECT 1 EXCEPT SELECT 2 +SELECT 1 EXCEPT SELECT 2 +SELECT 1 INTERSECT SELECT 2 +SELECT 1 INTERSECT SELECT 2 +SELECT 1 AS delete, 2 AS alter +SELECT * FROM (x) +SELECT * FROM ((x)) +SELECT * FROM ((SELECT 1)) +SELECT * FROM (SELECT 1) AS x +SELECT * FROM (SELECT 1 UNION SELECT 2) AS x +SELECT * FROM (SELECT 1 UNION ALL SELECT 2) AS x +SELECT * FROM (SELECT 1 UNION ALL SELECT 2) +SELECT * FROM ((SELECT 1) AS a UNION ALL (SELECT 2) AS b) +SELECT * FROM ((SELECT 1) AS a(b)) +SELECT * FROM x AS y(a, b) +SELECT * EXCEPT (a, b) +SELECT * REPLACE (a AS b, b AS C) +SELECT * REPLACE (a + 1 AS b, b AS C) +SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C) +SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C) +SELECT zoo, animals FROM (VALUES ('oakland', ARRAY('a', 'b')), ('sf', ARRAY('b', 'c'))) AS t(zoo, animals) +WITH a AS (SELECT 1) SELECT 1 UNION ALL SELECT 2 +WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2 +WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2 +WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2 +WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2 +(SELECT 1) UNION (SELECT 2) +(SELECT 1) UNION SELECT 2 +SELECT 1 UNION (SELECT 2) +(SELECT 1) ORDER BY x LIMIT 1 OFFSET 1 +(SELECT 1 UNION SELECT 2) UNION (SELECT 2 UNION ALL SELECT 3) +(SELECT 1 UNION SELECT 2) ORDER BY x LIMIT 1 OFFSET 1 +(SELECT 1 UNION SELECT 2) CLUSTER BY y DESC +(SELECT 1 UNION SELECT 2) SORT BY z +(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z +(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z SORT BY x +SELECT 1 UNION (SELECT 2) ORDER BY x +(SELECT 1) UNION SELECT 2 ORDER BY x +SELECT * FROM (((SELECT 1) UNION SELECT 2) ORDER BY x LIMIT 1 OFFSET 1) +SELECT * FROM ((SELECT 1 AS x) CROSS 
JOIN (SELECT 2 AS y)) AS z +((SELECT 1) EXCEPT (SELECT 2)) +VALUES (1) UNION SELECT * FROM x +WITH a AS (SELECT 1) SELECT a.* FROM a +WITH a AS (SELECT 1), b AS (SELECT 2) SELECT a.*, b.* FROM a CROSS JOIN b +WITH a AS (WITH b AS (SELECT 1 AS x) SELECT b.x FROM b) SELECT a.x FROM a +WITH RECURSIVE T(n) AS (VALUES (1) UNION ALL SELECT n + 1 FROM t WHERE n < 100) SELECT SUM(n) FROM t +WITH RECURSIVE T(n, m) AS (VALUES (1, 2) UNION ALL SELECT n + 1, n + 2 FROM t) SELECT SUM(n) FROM t +WITH baz AS (SELECT 1 AS col) UPDATE bar SET cid = baz.col1 FROM baz +SELECT * FROM (WITH y AS (SELECT 1 AS z) SELECT z FROM y) AS x +SELECT RANK() OVER () FROM x +SELECT RANK() OVER () AS y FROM x +SELECT RANK() OVER (PARTITION BY a) FROM x +SELECT RANK() OVER (PARTITION BY a, b) FROM x +SELECT RANK() OVER (ORDER BY a) FROM x +SELECT RANK() OVER (ORDER BY a, b) FROM x +SELECT RANK() OVER (PARTITION BY a ORDER BY a) FROM x +SELECT RANK() OVER (PARTITION BY a, b ORDER BY a, b DESC) FROM x +SELECT SUM(x) OVER (PARTITION BY a) AS y FROM x +SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +SELECT SUM(x) OVER (PARTITION BY a ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +SELECT SUM(x) OVER (PARTITION BY a ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +SELECT SUM(x) OVER (PARTITION BY a ORDER BY b RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND CURRENT ROW) +SELECT SUM(x) OVER (PARTITION BY a ORDER BY b RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND INTERVAL '2' DAYS FOLLOWING) +SELECT SUM(x) OVER (PARTITION BY a ORDER BY b RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND UNBOUNDED FOLLOWING) +SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN UNBOUNDED PRECEDING AND PRECEDING) +SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) +SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 AND 3) +SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND 3) +SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING) +SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y +SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC) +SELECT SUM(x) FILTER(WHERE x > 1) +SELECT SUM(x) FILTER(WHERE x > 1) OVER (ORDER BY y) +SELECT COUNT(DISTINCT a) OVER (PARTITION BY c ORDER BY d ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) +SELECT a['1'], b[0], x.c[0], "x".d['1'] FROM x +SELECT ARRAY(1, 2, 3) FROM x +SELECT ARRAY(ARRAY(1), ARRAY(2)) FROM x +SELECT MAP[ARRAY(1), ARRAY(2)] FROM x +SELECT MAP(ARRAY(1), ARRAY(2)) FROM x +SELECT MAX(ARRAY(1, 2, 3)) FROM x +SELECT ARRAY(ARRAY(0))[0][0] FROM x +SELECT MAP[ARRAY('x'), ARRAY(0)]['x'] FROM x +SELECT student, score FROM tests LATERAL VIEW EXPLODE(scores) +SELECT student, score FROM tests LATERAL VIEW EXPLODE(scores) AS score +SELECT student, score FROM tests LATERAL VIEW EXPLODE(scores) t AS score +SELECT student, score FROM tests LATERAL VIEW EXPLODE(scores) t AS score, name +SELECT student, score FROM tests LATERAL VIEW OUTER EXPLODE(scores) t AS score, name +SELECT tf.* FROM (SELECT 0) AS t LATERAL VIEW STACK(1, 2) tf +SELECT tf.* FROM (SELECT 0) AS t LATERAL VIEW STACK(1, 2) tf AS col0, col1, col2 +SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(score) +SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b) +SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY 
AS t(a, b) +SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score) +SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score) +CREATE TABLE a.b AS SELECT 1 +CREATE TABLE a.b AS SELECT a FROM a.c +CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d +CREATE TEMPORARY TABLE x AS SELECT a FROM d +CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d +CREATE VIEW x AS SELECT a FROM b +CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b +CREATE OR REPLACE VIEW x AS SELECT * +CREATE OR REPLACE TEMPORARY VIEW x AS SELECT * +CREATE TEMPORARY VIEW x AS SELECT a FROM d +CREATE TEMPORARY VIEW IF NOT EXISTS x AS SELECT a FROM d +CREATE TEMPORARY VIEW x AS WITH y AS (SELECT 1) SELECT * FROM y +CREATE TABLE z (a INT, b VARCHAR, c VARCHAR(100), d DECIMAL(5, 3)) +CREATE TABLE z (a ARRAY, b MAP, c DECIMAL(5, 3)) +CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECIMAL(5, 3)) +CREATE TABLE z (a INT(11) DEFAULT UUID()) +CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id') +CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1) +CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT) +CREATE TABLE z (a INT, PRIMARY KEY(a)) +CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x' +CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x' +CREATE TABLE z (a INT DEFAULT NULL, PRIMARY KEY(a)) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x' +CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1 +CREATE TABLE z WITH (FORMAT='ORC', x = '2') AS SELECT 1 +CREATE TABLE z WITH (TABLE_FORMAT='iceberg', FORMAT='parquet') AS SELECT 1 +CREATE TABLE z WITH (TABLE_FORMAT='iceberg', FORMAT='ORC', x = '2') AS SELECT 1 +CREATE TABLE z (z INT) WITH (PARTITIONED_BY=(x INT, y INT)) +CREATE TABLE z (z INT) WITH (PARTITIONED_BY=(x INT)) AS SELECT 1 +CREATE TABLE z AS (WITH cte AS (SELECT 1) SELECT * FROM cte) +CREATE TABLE z AS ((WITH cte AS (SELECT 1) SELECT * FROM cte)) +CREATE TABLE z (a INT UNIQUE) +CREATE TABLE z (a INT AUTO_INCREMENT) +CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT) +CREATE TEMPORARY FUNCTION f +CREATE TEMPORARY FUNCTION f AS 'g' +CREATE FUNCTION f +CREATE FUNCTION f AS 'g' +CREATE INDEX abc ON t (a) +CREATE INDEX abc ON t (a, b, b) +CREATE UNIQUE INDEX abc ON t (a, b, b) +CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl (a, b) +CACHE TABLE x +CACHE LAZY TABLE x +CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') +CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1 +CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS WITH a AS (SELECT 1) SELECT a.* FROM a +CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a +CACHE TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a +CALL catalog.system.iceberg_procedure_name(named_arg_1 => 'arg_1', named_arg_2 => 'arg_2') +INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y +INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y +INSERT OVERWRITE TABLE a.b PARTITION(ds, hour) SELECT x FROM y +INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD', hour='hh') SELECT x FROM y +ALTER TYPE electronic_mail RENAME TO email +ANALYZE a.y +DELETE FROM x WHERE y > 1 +DELETE FROM y +DROP TABLE a +DROP TABLE a.b +DROP TABLE IF EXISTS a +DROP TABLE IF EXISTS a.b +DROP VIEW a +DROP VIEW a.b +DROP VIEW IF EXISTS a +DROP VIEW IF EXISTS a.b +SHOW TABLES +EXPLAIN SELECT * FROM x +INSERT INTO x SELECT * FROM y +INSERT INTO x (SELECT * FROM y) +INSERT INTO x 
WITH y AS (SELECT 1) SELECT * FROM y +INSERT INTO x.z IF EXISTS SELECT * FROM y +INSERT INTO x VALUES (1, 'a', 2.0) +INSERT INTO x VALUES (1, 'a', 2.0), (1, 'a', 3.0), (X(), y[1], z.x) +INSERT INTO y (a, b, c) SELECT a, b, c FROM x +INSERT OVERWRITE TABLE x IF EXISTS SELECT * FROM y +INSERT OVERWRITE TABLE a.b IF EXISTS SELECT * FROM y +SELECT 1 FROM PARQUET_SCAN('/x/y/*') AS y +UNCACHE TABLE x +UNCACHE TABLE IF EXISTS x +UPDATE tbl_name SET foo = 123 +UPDATE tbl_name SET foo = 123, bar = 345 +UPDATE db.tbl_name SET foo = 123 WHERE tbl_name.bar = 234 +UPDATE db.tbl_name SET foo = 123, foo_1 = 234 WHERE tbl_name.bar = 234 +TRUNCATE TABLE x +OPTIMIZE TABLE y +WITH a AS (SELECT 1) INSERT INTO b SELECT * FROM a +WITH a AS (SELECT * FROM b) UPDATE a SET col = 1 +WITH a AS (SELECT * FROM b) CREATE TABLE b AS SELECT * FROM a +WITH a AS (SELECT * FROM b) DELETE FROM a +WITH a AS (SELECT * FROM b) CACHE TABLE a +SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ? +WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a +WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a +SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z diff --git a/tests/fixtures/optimizer/eliminate_subqueries.sql b/tests/fixtures/optimizer/eliminate_subqueries.sql new file mode 100644 index 0000000..aae5f2a --- /dev/null +++ b/tests/fixtures/optimizer/eliminate_subqueries.sql @@ -0,0 +1,42 @@ +SELECT 1 AS x, 2 AS y +UNION ALL +SELECT 1 AS x, 2 AS y; +WITH _e_0 AS ( + SELECT + 1 AS x, + 2 AS y +) +SELECT + * +FROM _e_0 +UNION ALL +SELECT + * +FROM _e_0; + +SELECT x.id +FROM ( + SELECT * + FROM x AS x + JOIN y AS y + ON x.id = y.id +) AS x +JOIN ( + SELECT * + FROM x AS x + JOIN y AS y + ON x.id = y.id +) AS y +ON x.id = y.id; +WITH _e_0 AS ( + SELECT + * + FROM x AS x + JOIN y AS y + ON x.id = y.id +) +SELECT + x.id +FROM "_e_0" AS x +JOIN "_e_0" AS y + ON x.id = y.id; diff --git a/tests/fixtures/optimizer/expand_multi_table_selects.sql b/tests/fixtures/optimizer/expand_multi_table_selects.sql new file mode 100644 index 0000000..a5a4664 --- /dev/null +++ b/tests/fixtures/optimizer/expand_multi_table_selects.sql @@ -0,0 +1,11 @@ +-------------------------------------- +-- Multi Table Selects +-------------------------------------- +SELECT * FROM x AS x, y AS y WHERE x.a = y.a; +SELECT * FROM x AS x CROSS JOIN y AS y WHERE x.a = y.a; + +SELECT * FROM x AS x, y AS y WHERE x.a = y.a AND x.a = 1 and y.b = 1; +SELECT * FROM x AS x CROSS JOIN y AS y WHERE x.a = y.a AND x.a = 1 AND y.b = 1; + +SELECT * FROM x AS x, y AS y WHERE x.a > y.a; +SELECT * FROM x AS x CROSS JOIN y AS y WHERE x.a > y.a; diff --git a/tests/fixtures/optimizer/isolate_table_selects.sql b/tests/fixtures/optimizer/isolate_table_selects.sql new file mode 100644 index 0000000..3b9a938 --- /dev/null +++ b/tests/fixtures/optimizer/isolate_table_selects.sql @@ -0,0 +1,20 @@ +SELECT * FROM x AS x, y AS y2; +SELECT * FROM (SELECT * FROM x AS x) AS x, (SELECT * FROM y AS y) AS y2; + +SELECT * FROM x AS x WHERE x = 1; +SELECT * FROM x AS x WHERE x = 1; + +SELECT * FROM x AS x JOIN y AS y; +SELECT * FROM (SELECT * FROM x AS x) AS x JOIN (SELECT * FROM y AS y) AS y; + +SELECT * FROM (SELECT 1) AS x JOIN y AS y; +SELECT * FROM (SELECT 1) AS x JOIN (SELECT * FROM y AS y) AS y; + +SELECT * FROM x AS x JOIN (SELECT * FROM y) AS y; +SELECT * FROM (SELECT * FROM x AS x) AS x JOIN (SELECT * FROM y) AS y; + +WITH y AS (SELECT *) SELECT * FROM x AS x; +WITH y AS (SELECT *) SELECT * 
FROM x AS x; + +WITH y AS (SELECT * FROM y AS y2 JOIN x AS z2) SELECT * FROM x AS x JOIN y as y; +WITH y AS (SELECT * FROM (SELECT * FROM y AS y) AS y2 JOIN (SELECT * FROM x AS x) AS z2) SELECT * FROM (SELECT * FROM x AS x) AS x JOIN y AS y; diff --git a/tests/fixtures/optimizer/normalize.sql b/tests/fixtures/optimizer/normalize.sql new file mode 100644 index 0000000..a84fadf --- /dev/null +++ b/tests/fixtures/optimizer/normalize.sql @@ -0,0 +1,41 @@ +(A OR B) AND (B OR C) AND (E OR F); +(A OR B) AND (B OR C) AND (E OR F); + +(A AND B) OR (B AND C AND D); +(A OR C) AND (A OR D) AND B; + +(A OR B) AND (A OR C) AND (A OR D) AND (B OR C) AND (B OR D) AND B; +(A OR C) AND (A OR D) AND B; + +(A AND E) OR (B AND C) OR (D AND (E OR F)); +(A OR B OR D) AND (A OR C OR D) AND (B OR D OR E) AND (B OR E OR F) AND (C OR D OR E) AND (C OR E OR F); + +(A AND B AND C AND D AND E AND F AND G) OR (H AND I AND J AND K AND L AND M AND N) OR (O AND P AND Q); +(A AND B AND C AND D AND E AND F AND G) OR (H AND I AND J AND K AND L AND M AND N) OR (O AND P AND Q); + +NOT NOT NOT (A OR B); +NOT A AND NOT B; + +A OR B; +A OR B; + +A AND (B AND C); +A AND B AND C; + +A OR (B AND C); +(A OR B) AND (A OR C); + +(A AND B) OR C; +(A OR C) AND (B OR C); + +A OR (B OR (C AND D)); +(A OR B OR C) AND (A OR B OR D); + +A OR ((((B OR C) AND (B OR D)) OR C) AND (((B OR C) AND (B OR D)) OR D)); +(A OR B OR C) AND (A OR B OR D); + +(A AND B) OR (C AND D); +(A OR C) AND (A OR D) AND (B OR C) AND (B OR D); + +(A AND B) OR (C OR (D AND E)); +(A OR C OR D) AND (A OR C OR E) AND (B OR C OR D) AND (B OR C OR E); diff --git a/tests/fixtures/optimizer/optimize_joins.sql b/tests/fixtures/optimizer/optimize_joins.sql new file mode 100644 index 0000000..b64544e --- /dev/null +++ b/tests/fixtures/optimizer/optimize_joins.sql @@ -0,0 +1,20 @@ +SELECT * FROM x JOIN y ON y.a = 1 JOIN z ON x.a = z.a AND y.a = z.a; +SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = 1 AND y.a = z.a; + +SELECT * FROM x JOIN y ON y.a = 1 JOIN z ON x.a = z.a; +SELECT * FROM x JOIN y ON y.a = 1 JOIN z ON x.a = z.a; + +SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a; +SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a; + +SELECT * FROM x LEFT JOIN y ON y.a = 1 JOIN z ON x.a = z.a AND y.a = z.a; +SELECT * FROM x JOIN z ON x.a = z.a AND TRUE LEFT JOIN y ON y.a = 1 AND y.a = z.a; + +SELECT * FROM x INNER JOIN z; +SELECT * FROM x JOIN z; + +SELECT * FROM x LEFT OUTER JOIN z; +SELECT * FROM x LEFT JOIN z; + +SELECT * FROM x CROSS JOIN z; +SELECT * FROM x CROSS JOIN z; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql new file mode 100644 index 0000000..f7bbdda --- /dev/null +++ b/tests/fixtures/optimizer/optimizer.sql @@ -0,0 +1,148 @@ +SELECT a, m FROM z LATERAL VIEW EXPLODE([1, 2]) q AS m; +SELECT + "z"."a" AS "a", + "q"."m" AS "m" +FROM ( + SELECT + "z"."a" AS "a" + FROM "z" AS "z" +) AS "z" +LATERAL VIEW +EXPLODE(ARRAY(1, 2)) q AS "m"; + +SELECT x FROM UNNEST([1, 2]) AS q(x, y); +SELECT + "q"."x" AS "x" +FROM UNNEST(ARRAY(1, 2)) AS "q"("x", "y"); + +WITH cte AS ( + ( + SELECT + a + FROM + x + ) + UNION ALL + ( + SELECT + a + FROM + y + ) +) +SELECT + * +FROM + cte; +WITH "cte" AS ( + ( + SELECT + "x"."a" AS "a" + FROM "x" AS "x" + ) + UNION ALL + ( + SELECT + "y"."a" AS "a" + FROM "y" AS "y" + ) +) +SELECT + "cte"."a" AS "a" +FROM "cte"; + +WITH cte1 AS ( + SELECT a + FROM x +), cte2 AS ( + SELECT a + 1 AS a + FROM cte1 +) +SELECT + a +FROM cte1 +UNION ALL +SELECT + a +FROM cte2; +WITH 
"cte1" AS ( + SELECT + "x"."a" AS "a" + FROM "x" AS "x" +), "cte2" AS ( + SELECT + "cte1"."a" + 1 AS "a" + FROM "cte1" +) +SELECT + "cte1"."a" AS "a" +FROM "cte1" +UNION ALL +SELECT + "cte2"."a" AS "a" +FROM "cte2"; + +SELECT a, SUM(b) +FROM ( + SELECT x.a, y.b + FROM x, y + WHERE (SELECT max(b) FROM y WHERE x.a = y.a) >= 0 AND x.a = y.a +) d +WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1 +GROUP BY a; +SELECT + "d"."a" AS "a", + SUM("d"."b") AS "_col_1" +FROM ( + SELECT + "x"."a" AS "a", + "y"."b" AS "b" + FROM ( + SELECT + "x"."a" AS "a" + FROM "x" AS "x" + WHERE + "x"."a" > 1 + ) AS "x" + LEFT JOIN ( + SELECT + MAX("y"."b") AS "_col_0", + "y"."a" AS "_u_1" + FROM "y" AS "y" + GROUP BY + "y"."a" + ) AS "_u_0" + ON "x"."a" = "_u_0"."_u_1" + JOIN ( + SELECT + "y"."a" AS "a", + "y"."b" AS "b" + FROM "y" AS "y" + ) AS "y" + ON "x"."a" = "y"."a" + WHERE + "_u_0"."_col_0" >= 0 + AND NOT "_u_0"."_u_1" IS NULL +) AS "d" +GROUP BY + "d"."a"; + +(SELECT a FROM x) LIMIT 1; +( + SELECT + "x"."a" AS "a" + FROM "x" AS "x" +) +LIMIT 1; + +(SELECT b FROM x UNION SELECT b FROM y) LIMIT 1; +( + SELECT + "x"."b" AS "b" + FROM "x" AS "x" + UNION + SELECT + "y"."b" AS "b" + FROM "y" AS "y" +) +LIMIT 1; diff --git a/tests/fixtures/optimizer/pushdown_predicates.sql b/tests/fixtures/optimizer/pushdown_predicates.sql new file mode 100644 index 0000000..676cb96 --- /dev/null +++ b/tests/fixtures/optimizer/pushdown_predicates.sql @@ -0,0 +1,32 @@ +SELECT x.a AS a FROM (SELECT x.a FROM x AS x) AS x JOIN y WHERE x.a = 1 AND x.b = 1 AND y.a = 1; +SELECT x.a AS a FROM (SELECT x.a FROM x AS x WHERE x.a = 1 AND x.b = 1) AS x JOIN y ON y.a = 1 WHERE TRUE AND TRUE AND TRUE; + +WITH x AS (SELECT y.a FROM y) SELECT * FROM x WHERE x.a = 1; +WITH x AS (SELECT y.a FROM y WHERE y.a = 1) SELECT * FROM x WHERE TRUE; + +SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE y.a = 1 OR (x.a = 1 AND x.b = 1); +SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = 1 AND x.b = 1) OR y.a = 1; + +SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.a; +SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.a WHERE TRUE; + +SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.b; +SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.a OR x.a = y.b WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.b; + +SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x) AS x WHERE x.c = 1; +SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x WHERE x.b * 1 = 1) AS x WHERE TRUE; + +SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x) AS x WHERE x.c = 1 or x.c = 2; +SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x WHERE x.b * 1 = 1 OR x.b * 1 = 2) AS x WHERE TRUE; + +SELECT x.a AS a FROM (SELECT x.a FROM x AS x) AS x JOIN y WHERE x.a = 1 AND x.b = 1 AND (x.c = 1 OR y.c = 1); +SELECT x.a AS a FROM (SELECT x.a FROM x AS x WHERE x.a = 1 AND x.b = 1) AS x JOIN y ON x.c = 1 OR y.c = 1 WHERE TRUE AND TRUE AND (TRUE); + +SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y) AS y ON y.a = 1 AND x.a = y.a; +SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y WHERE y.a = 1) AS y ON x.a = y.a AND TRUE; + +SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y) AS y ON y.a = 1 WHERE x.a = 1 AND x.b = 1 AND y.a = x; +SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.a = x AND TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE; + +SELECT x.a AS a FROM x AS x CROSS JOIN (SELECT * FROM y AS y) AS y WHERE x.a = 1 AND x.b = 1 AND y.a = x.a AND y.a 
= 1; +SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.a = x.a AND TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE AND TRUE; diff --git a/tests/fixtures/optimizer/pushdown_projections.sql b/tests/fixtures/optimizer/pushdown_projections.sql new file mode 100644 index 0000000..9deceb6 --- /dev/null +++ b/tests/fixtures/optimizer/pushdown_projections.sql @@ -0,0 +1,41 @@ +SELECT a FROM (SELECT * FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; + +SELECT 1 FROM (SELECT * FROM x) WHERE b = 2; +SELECT 1 AS "_col_0" FROM (SELECT x.b AS b FROM x AS x) AS "_q_0" WHERE "_q_0".b = 2; + +SELECT (SELECT c FROM y WHERE q.b = y.b) FROM (SELECT * FROM x) AS q; +SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS "_col_0" FROM (SELECT x.b AS b FROM x AS x) AS q; + +SELECT a FROM x JOIN (SELECT b, c FROM y) AS z ON x.b = z.b; +SELECT x.a AS a FROM x AS x JOIN (SELECT y.b AS b FROM y AS y) AS z ON x.b = z.b; + +SELECT x1.a FROM (SELECT * FROM x) AS x1, (SELECT * FROM x) AS x2; +SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS "_" FROM x AS x) AS x2; + +SELECT x1.a FROM (SELECT * FROM x) AS x1, (SELECT * FROM x) AS x2; +SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS "_" FROM x AS x) AS x2; + +SELECT a FROM (SELECT DISTINCT a, b FROM x); +SELECT "_q_0".a AS a FROM (SELECT DISTINCT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; + +SELECT a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION ALL SELECT x.a AS a FROM x AS x) AS "_q_0"; + +SELECT a FROM (SELECT a, b FROM x UNION SELECT a, b FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x UNION SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; + +WITH y AS (SELECT * FROM x) SELECT a FROM y; +WITH y AS (SELECT x.a AS a FROM x AS x) SELECT y.a AS a FROM y; + +WITH z AS (SELECT * FROM x), q AS (SELECT b FROM z) SELECT b FROM q; +WITH z AS (SELECT x.b AS b FROM x AS x), q AS (SELECT z.b AS b FROM z) SELECT q.b AS b FROM q; + +WITH z AS (SELECT * FROM x) SELECT a FROM z UNION SELECT a FROM z; +WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z UNION SELECT z.a AS a FROM z; + +SELECT b FROM (SELECT a, SUM(b) AS b FROM x GROUP BY a); +SELECT "_q_0".b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS "_q_0"; + +SELECT b FROM (SELECT a, SUM(b) AS b FROM x ORDER BY a); +SELECT "_q_0".b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS "_q_0"; diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql new file mode 100644 index 0000000..004c57c --- /dev/null +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -0,0 +1,233 @@ +-------------------------------------- +-- Qualify columns +-------------------------------------- +SELECT a FROM x; +SELECT x.a AS a FROM x AS x; + +SELECT a FROM x AS z; +SELECT z.a AS a FROM x AS z; + +SELECT a AS a FROM x; +SELECT x.a AS a FROM x AS x; + +SELECT x.a FROM x; +SELECT x.a AS a FROM x AS x; + +SELECT x.a AS a FROM x; +SELECT x.a AS a FROM x AS x; + +SELECT a AS b FROM x; +SELECT x.a AS b FROM x AS x; + +SELECT 1, 2 FROM x; +SELECT 1 AS "_col_0", 2 AS "_col_1" FROM x AS x; + +SELECT a + b FROM x; +SELECT x.a + x.b AS "_col_0" FROM x AS x; + +SELECT a + b FROM x; +SELECT x.a + x.b AS "_col_0" FROM x AS x; + +SELECT a, SUM(b) FROM x WHERE a > 1 AND b > 1 GROUP BY a; +SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a; + +SELECT 
a AS j, b FROM x ORDER BY j; +SELECT x.a AS j, x.b AS b FROM x AS x ORDER BY j; + +SELECT a AS j, b FROM x GROUP BY j; +SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a; + +SELECT a, b FROM x GROUP BY 1, 2; +SELECT x.a AS a, x.b AS b FROM x AS x GROUP BY x.a, x.b; + +SELECT a, b FROM x ORDER BY 1, 2; +SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a, b; + +SELECT DATE(a), DATE(b) AS c FROM x GROUP BY 1, 2; +SELECT DATE(x.a) AS "_col_0", DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE(x.b); + +SELECT x.a AS c FROM x JOIN y ON x.b = y.b GROUP BY c; +SELECT x.a AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c; + +SELECT DATE(x.a) AS d FROM x JOIN y ON x.b = y.b GROUP BY d; +SELECT DATE(x.a) AS d FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY DATE(x.a); + +SELECT a AS a, b FROM x ORDER BY a; +SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a; + +SELECT a, b FROM x ORDER BY a; +SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a; + +SELECT a FROM x ORDER BY b; +SELECT x.a AS a FROM x AS x ORDER BY x.b; + +# dialect: bigquery +SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS row_num FROM x QUALIFY row_num = 1; +SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x QUALIFY row_num = 1; + +# dialect: bigquery +SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1; +SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1; + +-------------------------------------- +-- Derived tables +-------------------------------------- +SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y; +SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y; + +SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y(a); +SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y; + +SELECT y.c AS c FROM (SELECT x.a AS a, x.b AS b FROM x AS x) AS y(c); +SELECT y.c AS c FROM (SELECT x.a AS c, x.b AS b FROM x AS x) AS y; + +SELECT a FROM (SELECT a FROM x AS x) y; +SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y; + +SELECT a FROM (SELECT a AS a FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; + +SELECT a FROM (SELECT a FROM (SELECT a FROM x)); +SELECT "_q_1".a AS a FROM (SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0") AS "_q_1"; + +SELECT x.a FROM x AS x JOIN (SELECT * FROM x); +SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; + +-------------------------------------- +-- Joins +-------------------------------------- +SELECT a, c FROM x JOIN y ON x.b = y.b; +SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; + +SELECT a, c FROM x, y; +SELECT x.a AS a, y.c AS c FROM x AS x, y AS y; + +-------------------------------------- +-- Unions +-------------------------------------- +SELECT a FROM x UNION SELECT a FROM x; +SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x; + +SELECT a FROM x UNION SELECT a FROM x UNION SELECT a FROM x; +SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x; + +SELECT a FROM (SELECT a FROM x UNION SELECT a FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x) AS "_q_0"; + +-------------------------------------- +-- Subqueries +-------------------------------------- +SELECT a FROM x WHERE b IN (SELECT c FROM y); +SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y); + +SELECT (SELECT c FROM y) FROM x; +SELECT 
(SELECT y.c AS c FROM y AS y) AS "_col_0" FROM x AS x; + +SELECT a FROM (SELECT a FROM x) WHERE a IN (SELECT b FROM (SELECT b FROM y)); +SELECT "_q_1".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_1" WHERE "_q_1".a IN (SELECT "_q_0".b AS b FROM (SELECT y.b AS b FROM y AS y) AS "_q_0"); + +-------------------------------------- +-- Correlated subqueries +-------------------------------------- +SELECT a FROM x WHERE b IN (SELECT c FROM y WHERE y.b = x.a); +SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y WHERE y.b = x.a); + +SELECT a FROM x WHERE b IN (SELECT c FROM y WHERE y.b = a); +SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y WHERE y.b = x.a); + +SELECT a FROM x WHERE b IN (SELECT b FROM y AS x); +SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT x.b AS b FROM y AS x); + +SELECT a FROM x AS i WHERE b IN (SELECT b FROM y AS j WHERE j.b IN (SELECT c FROM y AS k WHERE k.b = j.b)); +SELECT i.a AS a FROM x AS i WHERE i.b IN (SELECT j.b AS b FROM y AS j WHERE j.b IN (SELECT k.c AS c FROM y AS k WHERE k.b = j.b)); + +# dialect: bigquery +SELECT aa FROM x, UNNEST(a) AS aa; +SELECT aa AS aa FROM x AS x, UNNEST(x.a) AS aa; + +SELECT aa FROM x, UNNEST(a) AS t(aa); +SELECT t.aa AS aa FROM x AS x, UNNEST(x.a) AS t(aa); + +-------------------------------------- +-- Expand * +-------------------------------------- +SELECT * FROM x; +SELECT x.a AS a, x.b AS b FROM x AS x; + +SELECT x.* FROM x; +SELECT x.a AS a, x.b AS b FROM x AS x; + +SELECT * FROM x JOIN y ON x.b = y.b; +SELECT x.a AS a, x.b AS b, y.b AS b, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; + +SELECT x.* FROM x JOIN y ON x.b = y.b; +SELECT x.a AS a, x.b AS b FROM x AS x JOIN y AS y ON x.b = y.b; + +SELECT x.*, y.* FROM x JOIN y ON x.b = y.b; +SELECT x.a AS a, x.b AS b, y.b AS b, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; + +SELECT a FROM (SELECT * FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; + +SELECT * FROM (SELECT a FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; + +-------------------------------------- +-- CTEs +-------------------------------------- +WITH z AS (SELECT x.a AS a FROM x) SELECT z.a AS a FROM z; +WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z; + +WITH z(a) AS (SELECT a FROM x) SELECT * FROM z; +WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z; + +WITH z AS (SELECT a FROM x) SELECT * FROM z as q; +WITH z AS (SELECT x.a AS a FROM x AS x) SELECT q.a AS a FROM z AS q; + +WITH z AS (SELECT a FROM x) SELECT * FROM z; +WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z; + +WITH z AS (SELECT a FROM x), q AS (SELECT * FROM z) SELECT * FROM q; +WITH z AS (SELECT x.a AS a FROM x AS x), q AS (SELECT z.a AS a FROM z) SELECT q.a AS a FROM q; + +WITH z AS (SELECT * FROM x) SELECT * FROM z UNION SELECT * FROM z; +WITH z AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT z.a AS a, z.b AS b FROM z UNION SELECT z.a AS a, z.b AS b FROM z; + +WITH z AS (SELECT * FROM x), q AS (SELECT b FROM z) SELECT b FROM q; +WITH z AS (SELECT x.a AS a, x.b AS b FROM x AS x), q AS (SELECT z.b AS b FROM z) SELECT q.b AS b FROM q; + +WITH z AS ((SELECT b FROM x UNION ALL SELECT b FROM y) ORDER BY b) SELECT * FROM z; +WITH z AS ((SELECT x.b AS b FROM x AS x UNION ALL SELECT y.b AS b FROM y AS y) ORDER BY b) SELECT z.b AS b FROM z; + +-------------------------------------- +-- Except and Replace +-------------------------------------- +SELECT * REPLACE(a AS d) FROM x; +SELECT x.a AS d, x.b AS b FROM 
x AS x; + +SELECT * EXCEPT(b) REPLACE(a AS d) FROM x; +SELECT x.a AS d FROM x AS x; + +SELECT x.* EXCEPT(a), y.* FROM x, y; +SELECT x.b AS b, y.b AS b, y.c AS c FROM x AS x, y AS y; + +SELECT * EXCEPT(a) FROM x; +SELECT x.b AS b FROM x AS x; + +-------------------------------------- +-- Using +-------------------------------------- +SELECT x.b FROM x JOIN y USING (b); +SELECT x.b AS b FROM x AS x JOIN y AS y ON x.b = y.b; + +SELECT x.b FROM x JOIN y USING (b) JOIN z USING (b); +SELECT x.b AS b FROM x AS x JOIN y AS y ON x.b = y.b JOIN z AS z ON x.b = z.b; + +SELECT b FROM x AS x2 JOIN y AS y2 USING (b); +SELECT COALESCE(x2.b, y2.b) AS b FROM x AS x2 JOIN y AS y2 ON x2.b = y2.b; + +SELECT b FROM x JOIN y USING (b) WHERE b = 1 and y.b = 2; +SELECT COALESCE(x.b, y.b) AS b FROM x AS x JOIN y AS y ON x.b = y.b WHERE COALESCE(x.b, y.b) = 1 AND y.b = 2; + +SELECT b FROM x JOIN y USING (b) JOIN z USING (b); +SELECT COALESCE(x.b, y.b, z.b) AS b FROM x AS x JOIN y AS y ON x.b = y.b JOIN z AS z ON x.b = z.b; diff --git a/tests/fixtures/optimizer/qualify_columns__invalid.sql b/tests/fixtures/optimizer/qualify_columns__invalid.sql new file mode 100644 index 0000000..056b0e9 --- /dev/null +++ b/tests/fixtures/optimizer/qualify_columns__invalid.sql @@ -0,0 +1,14 @@ +SELECT a FROM zz; +SELECT * FROM zz; +SELECT z.a FROM x; +SELECT z.* FROM x; +SELECT x FROM x; +INSERT INTO x VALUES (1, 2); +SELECT a FROM x AS z JOIN y AS z; +WITH z AS (SELECT * FROM x) SELECT * FROM x AS z; +SELECT a FROM x JOIN (SELECT b FROM y WHERE y.b = x.c); +SELECT a FROM x AS y JOIN (SELECT a FROM y) AS q ON y.a = q.a; +SELECT q.a FROM (SELECT x.b FROM x) AS z JOIN (SELECT a FROM z) AS q ON z.b = q.a; +SELECT b FROM x AS a CROSS JOIN y AS b CROSS JOIN y AS c; +SELECT x.a FROM x JOIN y USING (a); +SELECT a, SUM(b) FROM x GROUP BY 3; diff --git a/tests/fixtures/optimizer/qualify_tables.sql b/tests/fixtures/optimizer/qualify_tables.sql new file mode 100644 index 0000000..2cea85d --- /dev/null +++ b/tests/fixtures/optimizer/qualify_tables.sql @@ -0,0 +1,17 @@ +SELECT 1 FROM z; +SELECT 1 FROM c.db.z AS z; + +SELECT 1 FROM y.z; +SELECT 1 FROM c.y.z AS z; + +SELECT 1 FROM x.y.z; +SELECT 1 FROM x.y.z AS z; + +SELECT 1 FROM x.y.z AS z; +SELECT 1 FROM x.y.z AS z; + +WITH a AS (SELECT 1 FROM z) SELECT 1 FROM a; +WITH a AS (SELECT 1 FROM c.db.z AS z) SELECT 1 FROM a; + +SELECT (SELECT y.c FROM y AS y) FROM x; +SELECT (SELECT y.c FROM c.db.y AS y) FROM c.db.x AS x; diff --git a/tests/fixtures/optimizer/quote_identities.sql b/tests/fixtures/optimizer/quote_identities.sql new file mode 100644 index 0000000..407b7f6 --- /dev/null +++ b/tests/fixtures/optimizer/quote_identities.sql @@ -0,0 +1,8 @@ +SELECT a FROM x; +SELECT "a" FROM "x"; + +SELECT "a" FROM "x"; +SELECT "a" FROM "x"; + +SELECT x.a AS a FROM db.x; +SELECT "x"."a" AS "a" FROM "db"."x"; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql new file mode 100644 index 0000000..d7217cf --- /dev/null +++ b/tests/fixtures/optimizer/simplify.sql @@ -0,0 +1,350 @@ +-------------------------------------- +-- Conditions +-------------------------------------- +x AND x; +x; + +y OR y; +y; + +x AND NOT x; +FALSE; + +x OR NOT x; +TRUE; + +1 AND TRUE; +TRUE; + +TRUE AND TRUE; +TRUE; + +1 AND TRUE AND 1 AND 1; +TRUE; + +TRUE AND FALSE; +FALSE; + +FALSE AND FALSE; +FALSE; + +FALSE AND TRUE AND TRUE; +FALSE; + +x > y OR FALSE; +x > y; + +FALSE OR x = y; +x = y; + +1 = 1; +TRUE; + +1.0 = 1; +TRUE; + +'x' = 'y'; +FALSE; + +'x' = 'x'; +TRUE; + +NULL AND TRUE; +NULL; + 
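Each pair in this simplify fixture is an input expression followed by the expected result of the simplify rule. A minimal sketch of how one pair might be replayed, assuming the rule is importable from sqlglot.optimizer.simplify as in this release's layout:

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    # Boolean identities fold to constants, per the pairs above.
    assert simplify(parse_one("x AND NOT x")).sql() == "FALSE"
    assert simplify(parse_one("x OR NOT x")).sql() == "TRUE"
    # NULL propagates through AND under three-valued logic.
    assert simplify(parse_one("NULL AND TRUE")).sql() == "NULL"
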
+NULL AND NULL; +NULL; + +NULL OR TRUE; +TRUE; + +NULL OR NULL; +NULL; + +FALSE OR NULL; +NULL; + +NOT TRUE; +FALSE; + +NOT FALSE; +TRUE; + +NULL = NULL; +NULL; + +NOT (NOT TRUE); +TRUE; + +a AND (b OR b); +a AND b; + +a AND (b AND b); +a AND b; + +-------------------------------------- +-- Absorption +-------------------------------------- +(A OR B) AND (C OR NOT A); +(A OR B) AND (C OR NOT A); + +A AND (A OR B); +A; + +A AND D AND E AND (B OR A); +A AND D AND E; + +D AND A AND E AND (B OR A); +A AND D AND E; + +(A OR B) AND A; +A; + +C AND D AND (A OR B) AND E AND F AND A; +A AND C AND D AND E AND F; + +A OR (A AND B); +A; + +(A AND B) OR A; +A; + +A AND (NOT A OR B); +A AND B; + +(NOT A OR B) AND A; +A AND B; + +A OR (NOT A AND B); +A OR B; + +(A OR C) AND ((A OR C) OR B); +A OR C; + +(A OR C) AND (A OR B OR C); +A OR C; + +-------------------------------------- +-- Elimination +-------------------------------------- +(A AND B) OR (A AND NOT B); +A; + +(A AND B) OR (NOT A AND B); +B; + +(A AND NOT B) OR (A AND B); +A; + +(NOT A AND B) OR (A AND B); +B; + +(A OR B) AND (A OR NOT B); +A; + +(A OR B) AND (NOT A OR B); +B; + +(A OR NOT B) AND (A OR B); +A; + +(NOT A OR B) AND (A OR B); +B; + +(NOT A OR NOT B) AND (NOT A OR B); +NOT A; + +(NOT A OR NOT B) AND (NOT A OR NOT NOT B); +NOT A; + +E OR (A AND B) OR C OR D OR (A AND NOT B); +A OR C OR D OR E; + +-------------------------------------- +-- Associativity +-------------------------------------- +(A AND B) AND C; +A AND B AND C; + +A AND (B AND C); +A AND B AND C; + +(A OR B) OR C; +A OR B OR C; + +A OR (B OR C); +A OR B OR C; + +((A AND B) AND C) AND D; +A AND B AND C AND D; + +(((((A) AND B)) AND C)) AND D; +A AND B AND C AND D; + +-------------------------------------- +-- Comparison and Pruning +-------------------------------------- +A AND D AND B AND E AND F AND G AND E AND A; +A AND B AND D AND E AND F AND G; + +A AND NOT B AND C AND B; +FALSE; + +(a AND b AND c AND d) AND (d AND c AND b AND a); +a AND b AND c AND d; + +(c AND (a AND b)) AND ((b AND a) AND c); +a AND b AND c; + +(A AND B AND C) OR (C AND B AND A); +A AND B AND C; + +-------------------------------------- +-- Where removal +-------------------------------------- +SELECT x WHERE TRUE; +SELECT x; + +-------------------------------------- +-- Parenthesis removal +-------------------------------------- +(TRUE); +TRUE; + +(FALSE); +FALSE; + +(FALSE OR TRUE); +TRUE; + +TRUE OR (((FALSE) OR (TRUE)) OR FALSE); +TRUE; + +(NOT FALSE) AND (NOT TRUE); +FALSE; + +((NOT FALSE) AND (x = x)) AND (TRUE OR 1 <> 3); +TRUE; + +((NOT FALSE) AND (x = x)) AND (FALSE OR 1 <> 2); +TRUE; + +(('a' = 'a') AND TRUE and NOT FALSE); +TRUE; + +-------------------------------------- +-- Literals +-------------------------------------- +1 + 1; +2; + +0.06 + 0.01; +0.07; + +0.06 + 1; +1.06; + +1.2E+1 + 15E-3; +12.015; + +1.2E1 + 15E-3; +12.015; + +1 - 2; +-1; + +-1 + 3; +2; + +-(-1); +1; + +0.06 - 0.01; +0.05; + +3 * 4; +12; + +3.0 * 9; +27.0; + +0.03 * 0.73; +0.0219; + +1 / 3; +0; + +20.0 / 6; +3.333333333333333333333333333; + +10 / 5; +2; + +(1.0 * 3) * 4 - 2 * (5 / 2); +8.0; + +6 - 2 + 4 * 2 + a; +12 + a; + +a + 1 + 1 + 2; +a + 4; + +a + (1 + 1) + (10); +a + 12; + +5 + 4 * 3; +17; + +1 < 2; +TRUE; + +2 <= 2; +TRUE; + +2 >= 2; +TRUE; + +2 > 1; +TRUE; + +2 > 2.5; +FALSE; + +3 > 2.5; +TRUE; + +1 > NULL; +NULL; + +1 <= NULL; +NULL; + +1 IS NULL; +FALSE; + +NULL IS NULL; +TRUE; + +NULL IS NOT NULL; +FALSE; + +1 IS NOT NULL; +TRUE; + +date '1998-12-01' - interval '90' day; +CAST('1998-09-02' AS DATE); 
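The date pairs here show literal folding: a DATE literal plus or minus a known-unit INTERVAL collapses into a single CAST, while unrecognized units (the 'foo' pairs below) are left untouched. A small sketch under the same import assumption as the previous example:

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    # 1998-12-01 minus 90 days folds to a plain CAST, as in the fixture.
    expr = parse_one("date '1998-12-01' - interval '90' day")
    assert simplify(expr).sql() == "CAST('1998-09-02' AS DATE)"
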
+ +date '1998-12-01' + interval '1' week; +CAST('1998-12-08' AS DATE); + +interval '1' year + date '1998-01-01'; +CAST('1999-01-01' AS DATE); + +interval '1' year + date '1998-01-01' + 3 * 7 * 4; +CAST('1999-01-01' AS DATE) + 84; + +date '1998-12-01' - interval '90' foo; +CAST('1998-12-01' AS DATE) - INTERVAL '90' foo; + +date '1998-12-01' + interval '90' foo; +CAST('1998-12-01' AS DATE) + INTERVAL '90' foo; diff --git a/tests/fixtures/optimizer/tpc-h/customer.csv.gz b/tests/fixtures/optimizer/tpc-h/customer.csv.gz new file mode 100644 index 0000000..e0d149c Binary files /dev/null and b/tests/fixtures/optimizer/tpc-h/customer.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-h/lineitem.csv.gz b/tests/fixtures/optimizer/tpc-h/lineitem.csv.gz new file mode 100644 index 0000000..08e40d8 Binary files /dev/null and b/tests/fixtures/optimizer/tpc-h/lineitem.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-h/nation.csv.gz b/tests/fixtures/optimizer/tpc-h/nation.csv.gz new file mode 100644 index 0000000..d5bf6e3 Binary files /dev/null and b/tests/fixtures/optimizer/tpc-h/nation.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-h/orders.csv.gz b/tests/fixtures/optimizer/tpc-h/orders.csv.gz new file mode 100644 index 0000000..9b572bc Binary files /dev/null and b/tests/fixtures/optimizer/tpc-h/orders.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-h/part.csv.gz b/tests/fixtures/optimizer/tpc-h/part.csv.gz new file mode 100644 index 0000000..2dfdaa5 Binary files /dev/null and b/tests/fixtures/optimizer/tpc-h/part.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-h/partsupp.csv.gz b/tests/fixtures/optimizer/tpc-h/partsupp.csv.gz new file mode 100644 index 0000000..de9a2ce Binary files /dev/null and b/tests/fixtures/optimizer/tpc-h/partsupp.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-h/region.csv.gz b/tests/fixtures/optimizer/tpc-h/region.csv.gz new file mode 100644 index 0000000..3dbd31a Binary files /dev/null and b/tests/fixtures/optimizer/tpc-h/region.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-h/supplier.csv.gz b/tests/fixtures/optimizer/tpc-h/supplier.csv.gz new file mode 100644 index 0000000..8dad82a Binary files /dev/null and b/tests/fixtures/optimizer/tpc-h/supplier.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql new file mode 100644 index 0000000..482e231 --- /dev/null +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -0,0 +1,1810 @@ +-------------------------------------- +-- TPC-H 1 +-------------------------------------- +select + l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order +from + lineitem +where + CAST(l_shipdate AS DATE) <= date '1998-12-01' - interval '90' day +group by + l_returnflag, + l_linestatus +order by + l_returnflag, + l_linestatus; +SELECT + "lineitem"."l_returnflag" AS "l_returnflag", + "lineitem"."l_linestatus" AS "l_linestatus", + SUM("lineitem"."l_quantity") AS "sum_qty", + SUM("lineitem"."l_extendedprice") AS "sum_base_price", + SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + )) AS "sum_disc_price", + SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + ) * ( + 1 + "lineitem"."l_tax" + )) AS "sum_charge", + 
AVG("lineitem"."l_quantity") AS "avg_qty", + AVG("lineitem"."l_extendedprice") AS "avg_price", + AVG("lineitem"."l_discount") AS "avg_disc", + COUNT(*) AS "count_order" +FROM "lineitem" AS "lineitem" +WHERE + CAST("lineitem"."l_shipdate" AS DATE) <= CAST('1998-09-02' AS DATE) +GROUP BY + "lineitem"."l_returnflag", + "lineitem"."l_linestatus" +ORDER BY + "l_returnflag", + "l_linestatus"; + +-------------------------------------- +-- TPC-H 2 +-------------------------------------- +select + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment +from + part, + supplier, + partsupp, + nation, + region +where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and p_size = 15 + and p_type like '%BRASS' + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + and ps_supplycost = ( + select + min(ps_supplycost) + from + partsupp, + supplier, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + ) +order by + s_acctbal desc, + n_name, + s_name, + p_partkey +limit + 100; +WITH "_e_0" AS ( + SELECT + "partsupp"."ps_partkey" AS "ps_partkey", + "partsupp"."ps_suppkey" AS "ps_suppkey", + "partsupp"."ps_supplycost" AS "ps_supplycost" + FROM "partsupp" AS "partsupp" +), "_e_1" AS ( + SELECT + "region"."r_regionkey" AS "r_regionkey", + "region"."r_name" AS "r_name" + FROM "region" AS "region" + WHERE + "region"."r_name" = 'EUROPE' +) +SELECT + "supplier"."s_acctbal" AS "s_acctbal", + "supplier"."s_name" AS "s_name", + "nation"."n_name" AS "n_name", + "part"."p_partkey" AS "p_partkey", + "part"."p_mfgr" AS "p_mfgr", + "supplier"."s_address" AS "s_address", + "supplier"."s_phone" AS "s_phone", + "supplier"."s_comment" AS "s_comment" +FROM ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_mfgr" AS "p_mfgr", + "part"."p_type" AS "p_type", + "part"."p_size" AS "p_size" + FROM "part" AS "part" + WHERE + "part"."p_size" = 15 + AND "part"."p_type" LIKE '%BRASS' +) AS "part" +LEFT JOIN ( + SELECT + MIN("partsupp"."ps_supplycost") AS "_col_0", + "partsupp"."ps_partkey" AS "_u_1" + FROM "_e_0" AS "partsupp" + CROSS JOIN "_e_1" AS "region" + JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_regionkey" AS "n_regionkey" + FROM "nation" AS "nation" + ) AS "nation" + ON "nation"."n_regionkey" = "region"."r_regionkey" + JOIN ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" + ) AS "supplier" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" + AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey" + GROUP BY + "partsupp"."ps_partkey" +) AS "_u_0" + ON "part"."p_partkey" = "_u_0"."_u_1" +CROSS JOIN "_e_1" AS "region" +JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name", + "nation"."n_regionkey" AS "n_regionkey" + FROM "nation" AS "nation" +) AS "nation" + ON "nation"."n_regionkey" = "region"."r_regionkey" +JOIN "_e_0" AS "partsupp" + ON "part"."p_partkey" = "partsupp"."ps_partkey" +JOIN ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_name" AS "s_name", + "supplier"."s_address" AS "s_address", + "supplier"."s_nationkey" AS "s_nationkey", + "supplier"."s_phone" AS "s_phone", + "supplier"."s_acctbal" AS "s_acctbal", + "supplier"."s_comment" AS "s_comment" + FROM "supplier" AS "supplier" +) AS "supplier" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" + AND 
"supplier"."s_suppkey" = "partsupp"."ps_suppkey" +WHERE + "partsupp"."ps_supplycost" = "_u_0"."_col_0" + AND NOT "_u_0"."_u_1" IS NULL +ORDER BY + "s_acctbal" DESC, + "n_name", + "s_name", + "p_partkey" +LIMIT 100; + +-------------------------------------- +-- TPC-H 3 +-------------------------------------- +select + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) as revenue, + CAST(o_orderdate AS STRING) AS o_orderdate, + o_shippriority +from + customer, + orders, + lineitem +where + c_mktsegment = 'BUILDING' + and c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate < '1995-03-15' + and l_shipdate > '1995-03-15' +group by + l_orderkey, + o_orderdate, + o_shippriority +order by + revenue desc, + o_orderdate +limit + 10; +SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + )) AS "revenue", + CAST("orders"."o_orderdate" AS TEXT) AS "o_orderdate", + "orders"."o_shippriority" AS "o_shippriority" +FROM ( + SELECT + "customer"."c_custkey" AS "c_custkey", + "customer"."c_mktsegment" AS "c_mktsegment" + FROM "customer" AS "customer" + WHERE + "customer"."c_mktsegment" = 'BUILDING' +) AS "customer" +JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey", + "orders"."o_orderdate" AS "o_orderdate", + "orders"."o_shippriority" AS "o_shippriority" + FROM "orders" AS "orders" + WHERE + "orders"."o_orderdate" < '1995-03-15' +) AS "orders" + ON "customer"."c_custkey" = "orders"."o_custkey" +JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount", + "lineitem"."l_shipdate" AS "l_shipdate" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_shipdate" > '1995-03-15' +) AS "lineitem" + ON "lineitem"."l_orderkey" = "orders"."o_orderkey" +GROUP BY + "lineitem"."l_orderkey", + "orders"."o_orderdate", + "orders"."o_shippriority" +ORDER BY + "revenue" DESC, + "o_orderdate" +LIMIT 10; + +-------------------------------------- +-- TPC-H 4 +-------------------------------------- +select + o_orderpriority, + count(*) as order_count +from + orders +where + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + * + from + lineitem + where + l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) +group by + o_orderpriority +order by + o_orderpriority; +SELECT + "orders"."o_orderpriority" AS "o_orderpriority", + COUNT(*) AS "order_count" +FROM "orders" AS "orders" +LEFT JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_commitdate" < "lineitem"."l_receiptdate" + GROUP BY + "lineitem"."l_orderkey" +) AS "_u_0" + ON "_u_0"."l_orderkey" = "orders"."o_orderkey" +WHERE + "orders"."o_orderdate" < CAST('1993-10-01' AS DATE) + AND "orders"."o_orderdate" >= CAST('1993-07-01' AS DATE) + AND NOT "_u_0"."l_orderkey" IS NULL +GROUP BY + "orders"."o_orderpriority" +ORDER BY + "o_orderpriority"; + +-------------------------------------- +-- TPC-H 5 +-------------------------------------- +select + n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue +from + customer, + orders, + lineitem, + supplier, + nation, + region +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and l_suppkey = s_suppkey + and c_nationkey = s_nationkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + and o_orderdate >= date '1994-01-01' + and 
o_orderdate < date '1994-01-01' + interval '1' year +group by + n_name +order by + revenue desc; +SELECT + "nation"."n_name" AS "n_name", + SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + )) AS "revenue" +FROM ( + SELECT + "customer"."c_custkey" AS "c_custkey", + "customer"."c_nationkey" AS "c_nationkey" + FROM "customer" AS "customer" +) AS "customer" +JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey", + "orders"."o_orderdate" AS "o_orderdate" + FROM "orders" AS "orders" + WHERE + "orders"."o_orderdate" < CAST('1995-01-01' AS DATE) + AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE) +) AS "orders" + ON "customer"."c_custkey" = "orders"."o_custkey" +CROSS JOIN ( + SELECT + "region"."r_regionkey" AS "r_regionkey", + "region"."r_name" AS "r_name" + FROM "region" AS "region" + WHERE + "region"."r_name" = 'ASIA' +) AS "region" +JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name", + "nation"."n_regionkey" AS "n_regionkey" + FROM "nation" AS "nation" +) AS "nation" + ON "nation"."n_regionkey" = "region"."r_regionkey" +JOIN ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" +) AS "supplier" + ON "customer"."c_nationkey" = "supplier"."s_nationkey" + AND "supplier"."s_nationkey" = "nation"."n_nationkey" +JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_suppkey" AS "l_suppkey", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount" + FROM "lineitem" AS "lineitem" +) AS "lineitem" + ON "lineitem"."l_orderkey" = "orders"."o_orderkey" + AND "lineitem"."l_suppkey" = "supplier"."s_suppkey" +GROUP BY + "nation"."n_name" +ORDER BY + "revenue" DESC; + +-------------------------------------- +-- TPC-H 6 +-------------------------------------- +select + sum(l_extendedprice * l_discount) as revenue +from + lineitem +where + l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + and l_discount between 0.06 - 0.01 and 0.06 + 0.01 + and l_quantity < 24; +SELECT + SUM("lineitem"."l_extendedprice" * "lineitem"."l_discount") AS "revenue" +FROM "lineitem" AS "lineitem" +WHERE + "lineitem"."l_discount" BETWEEN 0.05 AND 0.07 + AND "lineitem"."l_quantity" < 24 + AND "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE) + AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE); + +-------------------------------------- +-- TPC-H 7 +-------------------------------------- +select + supp_nation, + cust_nation, + l_year, + sum(volume) as revenue +from + ( + select + n1.n_name as supp_nation, + n2.n_name as cust_nation, + extract(year from l_shipdate) as l_year, + l_extendedprice * (1 - l_discount) as volume + from + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + where + s_suppkey = l_suppkey + and o_orderkey = l_orderkey + and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey + and c_nationkey = n2.n_nationkey + and ( + (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + ) + and l_shipdate between date '1995-01-01' and date '1996-12-31' + ) as shipping +group by + supp_nation, + cust_nation, + l_year +order by + supp_nation, + cust_nation, + l_year; +WITH "_e_0" AS ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name" + FROM "nation" AS "nation" + WHERE + "nation"."n_name" = 'FRANCE' + OR "nation"."n_name" = 'GERMANY' +) 
+SELECT + "shipping"."supp_nation" AS "supp_nation", + "shipping"."cust_nation" AS "cust_nation", + "shipping"."l_year" AS "l_year", + SUM("shipping"."volume") AS "revenue" +FROM ( + SELECT + "n1"."n_name" AS "supp_nation", + "n2"."n_name" AS "cust_nation", + EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year", + "lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + ) AS "volume" + FROM ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" + ) AS "supplier" + JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_suppkey" AS "l_suppkey", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount", + "lineitem"."l_shipdate" AS "l_shipdate" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + ) AS "lineitem" + ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" + JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey" + FROM "orders" AS "orders" + ) AS "orders" + ON "orders"."o_orderkey" = "lineitem"."l_orderkey" + JOIN ( + SELECT + "customer"."c_custkey" AS "c_custkey", + "customer"."c_nationkey" AS "c_nationkey" + FROM "customer" AS "customer" + ) AS "customer" + ON "customer"."c_custkey" = "orders"."o_custkey" + JOIN "_e_0" AS "n1" + ON "supplier"."s_nationkey" = "n1"."n_nationkey" + JOIN "_e_0" AS "n2" + ON "customer"."c_nationkey" = "n2"."n_nationkey" + AND ( + "n1"."n_name" = 'FRANCE' + OR "n2"."n_name" = 'FRANCE' + ) + AND ( + "n1"."n_name" = 'GERMANY' + OR "n2"."n_name" = 'GERMANY' + ) +) AS "shipping" +GROUP BY + "shipping"."supp_nation", + "shipping"."cust_nation", + "shipping"."l_year" +ORDER BY + "supp_nation", + "cust_nation", + "l_year"; + +-------------------------------------- +-- TPC-H 8 +-------------------------------------- +select + o_year, + sum(case + when nation = 'BRAZIL' then volume + else 0 + end) / sum(volume) as mkt_share +from + ( + select + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) as volume, + n2.n_name as nation + from + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + where + p_partkey = l_partkey + and s_suppkey = l_suppkey + and l_orderkey = o_orderkey + and o_custkey = c_custkey + and c_nationkey = n1.n_nationkey + and n1.n_regionkey = r_regionkey + and r_name = 'AMERICA' + and s_nationkey = n2.n_nationkey + and o_orderdate between date '1995-01-01' and date '1996-12-31' + and p_type = 'ECONOMY ANODIZED STEEL' + ) as all_nations +group by + o_year +order by + o_year; +SELECT + "all_nations"."o_year" AS "o_year", + SUM(CASE + WHEN "all_nations"."nation" = 'BRAZIL' + THEN "all_nations"."volume" + ELSE 0 + END) / SUM("all_nations"."volume") AS "mkt_share" +FROM ( + SELECT + EXTRACT(year FROM "orders"."o_orderdate") AS "o_year", + "lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + ) AS "volume", + "n2"."n_name" AS "nation" + FROM ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_type" AS "p_type" + FROM "part" AS "part" + WHERE + "part"."p_type" = 'ECONOMY ANODIZED STEEL' + ) AS "part" + CROSS JOIN ( + SELECT + "region"."r_regionkey" AS "r_regionkey", + "region"."r_name" AS "r_name" + FROM "region" AS "region" + WHERE + "region"."r_name" = 'AMERICA' + ) AS "region" + JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_regionkey" AS "n_regionkey" + FROM "nation" AS "nation" + ) AS 
"n1" + ON "n1"."n_regionkey" = "region"."r_regionkey" + JOIN ( + SELECT + "customer"."c_custkey" AS "c_custkey", + "customer"."c_nationkey" AS "c_nationkey" + FROM "customer" AS "customer" + ) AS "customer" + ON "customer"."c_nationkey" = "n1"."n_nationkey" + JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey", + "orders"."o_orderdate" AS "o_orderdate" + FROM "orders" AS "orders" + WHERE + "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + ) AS "orders" + ON "orders"."o_custkey" = "customer"."c_custkey" + JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_partkey" AS "l_partkey", + "lineitem"."l_suppkey" AS "l_suppkey", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount" + FROM "lineitem" AS "lineitem" + ) AS "lineitem" + ON "lineitem"."l_orderkey" = "orders"."o_orderkey" + AND "part"."p_partkey" = "lineitem"."l_partkey" + JOIN ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" + ) AS "supplier" + ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" + JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name" + FROM "nation" AS "nation" + ) AS "n2" + ON "supplier"."s_nationkey" = "n2"."n_nationkey" +) AS "all_nations" +GROUP BY + "all_nations"."o_year" +ORDER BY + "o_year"; + +-------------------------------------- +-- TPC-H 9 +-------------------------------------- +select + nation, + o_year, + sum(amount) as sum_profit +from + ( + select + n_name as nation, + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%green%' + ) as profit +group by + nation, + o_year +order by + nation, + o_year desc; +SELECT + "profit"."nation" AS "nation", + "profit"."o_year" AS "o_year", + SUM("profit"."amount") AS "sum_profit" +FROM ( + SELECT + "nation"."n_name" AS "nation", + EXTRACT(year FROM "orders"."o_orderdate") AS "o_year", + "lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + ) - "partsupp"."ps_supplycost" * "lineitem"."l_quantity" AS "amount" + FROM ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_name" AS "p_name" + FROM "part" AS "part" + WHERE + "part"."p_name" LIKE '%green%' + ) AS "part" + JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_partkey" AS "l_partkey", + "lineitem"."l_suppkey" AS "l_suppkey", + "lineitem"."l_quantity" AS "l_quantity", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount" + FROM "lineitem" AS "lineitem" + ) AS "lineitem" + ON "part"."p_partkey" = "lineitem"."l_partkey" + JOIN ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" + ) AS "supplier" + ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" + JOIN ( + SELECT + "partsupp"."ps_partkey" AS "ps_partkey", + "partsupp"."ps_suppkey" AS "ps_suppkey", + "partsupp"."ps_supplycost" AS "ps_supplycost" + FROM "partsupp" AS "partsupp" + ) AS "partsupp" + ON "partsupp"."ps_partkey" = "lineitem"."l_partkey" + AND "partsupp"."ps_suppkey" = "lineitem"."l_suppkey" + JOIN ( + SELECT + 
"orders"."o_orderkey" AS "o_orderkey", + "orders"."o_orderdate" AS "o_orderdate" + FROM "orders" AS "orders" + ) AS "orders" + ON "orders"."o_orderkey" = "lineitem"."l_orderkey" + JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name" + FROM "nation" AS "nation" + ) AS "nation" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" +) AS "profit" +GROUP BY + "profit"."nation", + "profit"."o_year" +ORDER BY + "nation", + "o_year" DESC; + +-------------------------------------- +-- TPC-H 10 +-------------------------------------- +select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment +from + customer, + orders, + lineitem, + nation +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate >= date '1993-10-01' + and o_orderdate < date '1993-10-01' + interval '3' month + and l_returnflag = 'R' + and c_nationkey = n_nationkey +group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment +order by + revenue desc +limit + 20; +SELECT + "customer"."c_custkey" AS "c_custkey", + "customer"."c_name" AS "c_name", + SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + )) AS "revenue", + "customer"."c_acctbal" AS "c_acctbal", + "nation"."n_name" AS "n_name", + "customer"."c_address" AS "c_address", + "customer"."c_phone" AS "c_phone", + "customer"."c_comment" AS "c_comment" +FROM ( + SELECT + "customer"."c_custkey" AS "c_custkey", + "customer"."c_name" AS "c_name", + "customer"."c_address" AS "c_address", + "customer"."c_nationkey" AS "c_nationkey", + "customer"."c_phone" AS "c_phone", + "customer"."c_acctbal" AS "c_acctbal", + "customer"."c_comment" AS "c_comment" + FROM "customer" AS "customer" +) AS "customer" +JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey", + "orders"."o_orderdate" AS "o_orderdate" + FROM "orders" AS "orders" + WHERE + "orders"."o_orderdate" < CAST('1994-01-01' AS DATE) + AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE) +) AS "orders" + ON "customer"."c_custkey" = "orders"."o_custkey" +JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount", + "lineitem"."l_returnflag" AS "l_returnflag" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_returnflag" = 'R' +) AS "lineitem" + ON "lineitem"."l_orderkey" = "orders"."o_orderkey" +JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name" + FROM "nation" AS "nation" +) AS "nation" + ON "customer"."c_nationkey" = "nation"."n_nationkey" +GROUP BY + "customer"."c_custkey", + "customer"."c_name", + "customer"."c_acctbal", + "customer"."c_phone", + "nation"."n_name", + "customer"."c_address", + "customer"."c_comment" +ORDER BY + "revenue" DESC +LIMIT 20; + +-------------------------------------- +-- TPC-H 11 +-------------------------------------- +select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value +from + partsupp, + supplier, + nation +where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' +group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) +order by + value desc; +WITH "_e_0" AS ( + SELECT + "supplier"."s_suppkey" AS 
"s_suppkey", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" +), "_e_1" AS ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name" + FROM "nation" AS "nation" + WHERE + "nation"."n_name" = 'GERMANY' +) +SELECT + "partsupp"."ps_partkey" AS "ps_partkey", + SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value" +FROM ( + SELECT + "partsupp"."ps_partkey" AS "ps_partkey", + "partsupp"."ps_suppkey" AS "ps_suppkey", + "partsupp"."ps_availqty" AS "ps_availqty", + "partsupp"."ps_supplycost" AS "ps_supplycost" + FROM "partsupp" AS "partsupp" +) AS "partsupp" +JOIN "_e_0" AS "supplier" + ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" +JOIN "_e_1" AS "nation" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" +GROUP BY + "partsupp"."ps_partkey" +HAVING + SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > ( + SELECT + SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0" + FROM ( + SELECT + "partsupp"."ps_suppkey" AS "ps_suppkey", + "partsupp"."ps_availqty" AS "ps_availqty", + "partsupp"."ps_supplycost" AS "ps_supplycost" + FROM "partsupp" AS "partsupp" + ) AS "partsupp" + JOIN "_e_0" AS "supplier" + ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" + JOIN "_e_1" AS "nation" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" + ) +ORDER BY + "value" DESC; + +-------------------------------------- +-- TPC-H 12 +-------------------------------------- +select + l_shipmode, + sum(case + when o_orderpriority = '1-URGENT' + or o_orderpriority = '2-HIGH' + then 1 + else 0 + end) as high_line_count, + sum(case + when o_orderpriority <> '1-URGENT' + and o_orderpriority <> '2-HIGH' + then 1 + else 0 + end) as low_line_count +from + orders, + lineitem +where + o_orderkey = l_orderkey + and l_shipmode in ('MAIL', 'SHIP') + and l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_receiptdate >= date '1994-01-01' + and l_receiptdate < date '1994-01-01' + interval '1' year +group by + l_shipmode +order by + l_shipmode; +SELECT + "lineitem"."l_shipmode" AS "l_shipmode", + SUM(CASE + WHEN "orders"."o_orderpriority" = '1-URGENT' + OR "orders"."o_orderpriority" = '2-HIGH' + THEN 1 + ELSE 0 + END) AS "high_line_count", + SUM(CASE + WHEN "orders"."o_orderpriority" <> '1-URGENT' + AND "orders"."o_orderpriority" <> '2-HIGH' + THEN 1 + ELSE 0 + END) AS "low_line_count" +FROM ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_orderpriority" AS "o_orderpriority" + FROM "orders" AS "orders" +) AS "orders" +JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_shipdate" AS "l_shipdate", + "lineitem"."l_commitdate" AS "l_commitdate", + "lineitem"."l_receiptdate" AS "l_receiptdate", + "lineitem"."l_shipmode" AS "l_shipmode" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_commitdate" < "lineitem"."l_receiptdate" + AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE) + AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE) + AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate" + AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP') +) AS "lineitem" + ON "orders"."o_orderkey" = "lineitem"."l_orderkey" +GROUP BY + "lineitem"."l_shipmode" +ORDER BY + "l_shipmode"; + +-------------------------------------- +-- TPC-H 13 +-------------------------------------- +select + c_count, + count(*) as custdist +from + ( + select + c_custkey, + count(o_orderkey) + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like 
'%special%requests%' + group by + c_custkey + ) as c_orders (c_custkey, c_count) +group by + c_count +order by + custdist desc, + c_count desc; +SELECT + "c_orders"."c_count" AS "c_count", + COUNT(*) AS "custdist" +FROM ( + SELECT + COUNT("orders"."o_orderkey") AS "c_count" + FROM ( + SELECT + "customer"."c_custkey" AS "c_custkey" + FROM "customer" AS "customer" + ) AS "customer" + LEFT JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey", + "orders"."o_comment" AS "o_comment" + FROM "orders" AS "orders" + WHERE + NOT "orders"."o_comment" LIKE '%special%requests%' + ) AS "orders" + ON "customer"."c_custkey" = "orders"."o_custkey" + GROUP BY + "customer"."c_custkey" +) AS "c_orders" +GROUP BY + "c_orders"."c_count" +ORDER BY + "custdist" DESC, + "c_count" DESC; + +-------------------------------------- +-- TPC-H 14 +-------------------------------------- +select + 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from + lineitem, + part +where + l_partkey = p_partkey + and l_shipdate >= date '1995-09-01' + and l_shipdate < date '1995-09-01' + interval '1' month; +SELECT + 100.00 * SUM(CASE + WHEN "part"."p_type" LIKE 'PROMO%' + THEN "lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + ) + ELSE 0 + END) / SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + )) AS "promo_revenue" +FROM ( + SELECT + "lineitem"."l_partkey" AS "l_partkey", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount", + "lineitem"."l_shipdate" AS "l_shipdate" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE) + AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE) +) AS "lineitem" +JOIN ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_type" AS "p_type" + FROM "part" AS "part" +) AS "part" + ON "lineitem"."l_partkey" = "part"."p_partkey"; + +-------------------------------------- +-- TPC-H 15 +-------------------------------------- +with revenue (supplier_no, total_revenue) as ( + select + l_suppkey, + sum(l_extendedprice * (1 - l_discount)) + from + lineitem + where + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month + group by + l_suppkey) +select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue +from + supplier, + revenue +where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue + ) +order by + s_suppkey; +WITH "revenue" AS ( + SELECT + "lineitem"."l_suppkey" AS "supplier_no", + SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + )) AS "total_revenue" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_shipdate" < CAST('1996-04-01' AS DATE) + AND "lineitem"."l_shipdate" >= CAST('1996-01-01' AS DATE) + GROUP BY + "lineitem"."l_suppkey" +) +SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_name" AS "s_name", + "supplier"."s_address" AS "s_address", + "supplier"."s_phone" AS "s_phone", + "revenue"."total_revenue" AS "total_revenue" +FROM ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_name" AS "s_name", + "supplier"."s_address" AS "s_address", + "supplier"."s_phone" AS "s_phone" + FROM "supplier" AS "supplier" +) AS "supplier" +JOIN "revenue" + ON "revenue"."total_revenue" = ( + SELECT + MAX("revenue"."total_revenue") AS "_col_0" + FROM "revenue" + ) + AND "supplier"."s_suppkey" = 
"revenue"."supplier_no" +ORDER BY + "s_suppkey"; + +-------------------------------------- +-- TPC-H 16 +-------------------------------------- +select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt +from + partsupp, + part +where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' + ) +group by + p_brand, + p_type, + p_size +order by + supplier_cnt desc, + p_brand, + p_type, + p_size; +SELECT + "part"."p_brand" AS "p_brand", + "part"."p_type" AS "p_type", + "part"."p_size" AS "p_size", + COUNT(DISTINCT "partsupp"."ps_suppkey") AS "supplier_cnt" +FROM ( + SELECT + "partsupp"."ps_partkey" AS "ps_partkey", + "partsupp"."ps_suppkey" AS "ps_suppkey" + FROM "partsupp" AS "partsupp" +) AS "partsupp" +LEFT JOIN ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey" + FROM "supplier" AS "supplier" + WHERE + "supplier"."s_comment" LIKE '%Customer%Complaints%' + GROUP BY + "supplier"."s_suppkey" +) AS "_u_0" + ON "partsupp"."ps_suppkey" = "_u_0"."s_suppkey" +JOIN ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_brand" AS "p_brand", + "part"."p_type" AS "p_type", + "part"."p_size" AS "p_size" + FROM "part" AS "part" + WHERE + "part"."p_brand" <> 'Brand#45' + AND "part"."p_size" IN (49, 14, 23, 45, 19, 3, 36, 9) + AND NOT "part"."p_type" LIKE 'MEDIUM POLISHED%' +) AS "part" + ON "part"."p_partkey" = "partsupp"."ps_partkey" +WHERE + "_u_0"."s_suppkey" IS NULL +GROUP BY + "part"."p_brand", + "part"."p_type", + "part"."p_size" +ORDER BY + "supplier_cnt" DESC, + "p_brand", + "p_type", + "p_size"; + +-------------------------------------- +-- TPC-H 17 +-------------------------------------- +select + sum(l_extendedprice) / 7.0 as avg_yearly +from + lineitem, + part +where + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey + ); +SELECT + SUM("lineitem"."l_extendedprice") / 7.0 AS "avg_yearly" +FROM ( + SELECT + "lineitem"."l_partkey" AS "l_partkey", + "lineitem"."l_quantity" AS "l_quantity", + "lineitem"."l_extendedprice" AS "l_extendedprice" + FROM "lineitem" AS "lineitem" +) AS "lineitem" +JOIN ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_brand" AS "p_brand", + "part"."p_container" AS "p_container" + FROM "part" AS "part" + WHERE + "part"."p_brand" = 'Brand#23' + AND "part"."p_container" = 'MED BOX' +) AS "part" + ON "part"."p_partkey" = "lineitem"."l_partkey" +LEFT JOIN ( + SELECT + 0.2 * AVG("lineitem"."l_quantity") AS "_col_0", + "lineitem"."l_partkey" AS "_u_1" + FROM "lineitem" AS "lineitem" + GROUP BY + "lineitem"."l_partkey" +) AS "_u_0" + ON "_u_0"."_u_1" = "part"."p_partkey" +WHERE + "lineitem"."l_quantity" < "_u_0"."_col_0" + AND NOT "_u_0"."_u_1" IS NULL; + +-------------------------------------- +-- TPC-H 18 +-------------------------------------- +select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) +from + customer, + orders, + lineitem +where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey +group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice +order by + o_totalprice desc, + o_orderdate +limit + 100; +SELECT + 
"customer"."c_name" AS "c_name", + "customer"."c_custkey" AS "c_custkey", + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_orderdate" AS "o_orderdate", + "orders"."o_totalprice" AS "o_totalprice", + SUM("lineitem"."l_quantity") AS "_col_5" +FROM ( + SELECT + "customer"."c_custkey" AS "c_custkey", + "customer"."c_name" AS "c_name" + FROM "customer" AS "customer" +) AS "customer" +JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey", + "orders"."o_totalprice" AS "o_totalprice", + "orders"."o_orderdate" AS "o_orderdate" + FROM "orders" AS "orders" +) AS "orders" + ON "customer"."c_custkey" = "orders"."o_custkey" +LEFT JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey" + FROM "lineitem" AS "lineitem" + GROUP BY + "lineitem"."l_orderkey", + "lineitem"."l_orderkey" + HAVING + SUM("lineitem"."l_quantity") > 300 +) AS "_u_0" + ON "orders"."o_orderkey" = "_u_0"."l_orderkey" +JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_quantity" AS "l_quantity" + FROM "lineitem" AS "lineitem" +) AS "lineitem" + ON "orders"."o_orderkey" = "lineitem"."l_orderkey" +WHERE + NOT "_u_0"."l_orderkey" IS NULL +GROUP BY + "customer"."c_name", + "customer"."c_custkey", + "orders"."o_orderkey", + "orders"."o_orderdate", + "orders"."o_totalprice" +ORDER BY + "o_totalprice" DESC, + "o_orderdate" +LIMIT 100; + +-------------------------------------- +-- TPC-H 19 +-------------------------------------- +select + sum(l_extendedprice* (1 - l_discount)) as revenue +from + lineitem, + part +where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 11 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 20 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 30 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ); +SELECT + SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + )) AS "revenue" +FROM ( + SELECT + "lineitem"."l_partkey" AS "l_partkey", + "lineitem"."l_quantity" AS "l_quantity", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount", + "lineitem"."l_shipinstruct" AS "l_shipinstruct", + "lineitem"."l_shipmode" AS "l_shipmode" + FROM "lineitem" AS "lineitem" +) AS "lineitem" +JOIN ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_brand" AS "p_brand", + "part"."p_size" AS "p_size", + "part"."p_container" AS "p_container" + FROM "part" AS "part" +) AS "part" + ON ( + "part"."p_brand" = 'Brand#12' + AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + AND "part"."p_partkey" = "lineitem"."l_partkey" + AND "part"."p_size" BETWEEN 1 AND 5 + ) + OR ( + "part"."p_brand" = 'Brand#23' + AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + AND "part"."p_partkey" = "lineitem"."l_partkey" + AND "part"."p_size" BETWEEN 1 AND 10 + ) + OR ( + "part"."p_brand" = 'Brand#34' + AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 
'LG PKG') + AND "part"."p_partkey" = "lineitem"."l_partkey" + AND "part"."p_size" BETWEEN 1 AND 15 + ) +WHERE + ( + "lineitem"."l_quantity" <= 11 + AND "lineitem"."l_quantity" >= 1 + AND "lineitem"."l_shipinstruct" = 'DELIVER IN PERSON' + AND "lineitem"."l_shipmode" IN ('AIR', 'AIR REG') + AND "part"."p_brand" = 'Brand#12' + AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + AND "part"."p_partkey" = "lineitem"."l_partkey" + AND "part"."p_size" BETWEEN 1 AND 5 + ) + OR ( + "lineitem"."l_quantity" <= 20 + AND "lineitem"."l_quantity" >= 10 + AND "lineitem"."l_shipinstruct" = 'DELIVER IN PERSON' + AND "lineitem"."l_shipmode" IN ('AIR', 'AIR REG') + AND "part"."p_brand" = 'Brand#23' + AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + AND "part"."p_partkey" = "lineitem"."l_partkey" + AND "part"."p_size" BETWEEN 1 AND 10 + ) + OR ( + "lineitem"."l_quantity" <= 30 + AND "lineitem"."l_quantity" >= 20 + AND "lineitem"."l_shipinstruct" = 'DELIVER IN PERSON' + AND "lineitem"."l_shipmode" IN ('AIR', 'AIR REG') + AND "part"."p_brand" = 'Brand#34' + AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + AND "part"."p_partkey" = "lineitem"."l_partkey" + AND "part"."p_size" BETWEEN 1 AND 15 + ); + +-------------------------------------- +-- TPC-H 20 +-------------------------------------- +select + s_name, + s_address +from + supplier, + nation +where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' +order by + s_name; +SELECT + "supplier"."s_name" AS "s_name", + "supplier"."s_address" AS "s_address" +FROM ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_name" AS "s_name", + "supplier"."s_address" AS "s_address", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" +) AS "supplier" +LEFT JOIN ( + SELECT + "partsupp"."ps_suppkey" AS "ps_suppkey" + FROM "partsupp" AS "partsupp" + LEFT JOIN ( + SELECT + 0.5 * SUM("lineitem"."l_quantity") AS "_col_0", + "lineitem"."l_partkey" AS "_u_1", + "lineitem"."l_suppkey" AS "_u_2" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE) + AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE) + GROUP BY + "lineitem"."l_partkey", + "lineitem"."l_suppkey" + ) AS "_u_0" + ON "_u_0"."_u_1" = "partsupp"."ps_partkey" + AND "_u_0"."_u_2" = "partsupp"."ps_suppkey" + LEFT JOIN ( + SELECT + "part"."p_partkey" AS "p_partkey" + FROM "part" AS "part" + WHERE + "part"."p_name" LIKE 'forest%' + GROUP BY + "part"."p_partkey" + ) AS "_u_3" + ON "partsupp"."ps_partkey" = "_u_3"."p_partkey" + WHERE + "partsupp"."ps_availqty" > "_u_0"."_col_0" + AND NOT "_u_0"."_u_1" IS NULL + AND NOT "_u_0"."_u_2" IS NULL + AND NOT "_u_3"."p_partkey" IS NULL + GROUP BY + "partsupp"."ps_suppkey" +) AS "_u_4" + ON "supplier"."s_suppkey" = "_u_4"."ps_suppkey" +JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name" + FROM "nation" AS "nation" + WHERE + "nation"."n_name" = 'CANADA' +) AS "nation" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" +WHERE + NOT "_u_4"."ps_suppkey" IS NULL +ORDER BY + "s_name"; + 
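+--
+-- A condensed sketch of the decorrelation applied to TPC-H 20 above (comment
+-- lines like these are stripped by the fixture loader in tests/helpers.py).
+-- The correlated aggregate
+--
+--   ps_availqty > (select 0.5 * sum(l_quantity) from lineitem
+--                  where l_partkey = ps_partkey and l_suppkey = ps_suppkey ...)
+--
+-- is unnested into an aggregate grouped by its correlation keys, with those
+-- keys projected under generated aliases and joined back on them:
+--
+--   LEFT JOIN (SELECT 0.5 * SUM(l_quantity) AS _col_0,
+--                     l_partkey AS _u_1, l_suppkey AS _u_2
+--              FROM lineitem GROUP BY l_partkey, l_suppkey) AS _u_0
+--     ON _u_0._u_1 = ps_partkey AND _u_0._u_2 = ps_suppkey
+--
+-- The comparison then reads ps_availqty > _u_0._col_0 in WHERE, and the
+-- NOT ... IS NULL guards on _u_1/_u_2 preserve the original IN semantics.
+--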
+-------------------------------------- +-- TPC-H 21 +-------------------------------------- +select + s_name, + count(*) as numwait +from + supplier, + lineitem l1, + orders, + nation +where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' +group by + s_name +order by + numwait desc, + s_name +limit + 100; +SELECT + "supplier"."s_name" AS "s_name", + COUNT(*) AS "numwait" +FROM ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_name" AS "s_name", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" +) AS "supplier" +JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_suppkey" AS "l_suppkey", + "lineitem"."l_commitdate" AS "l_commitdate", + "lineitem"."l_receiptdate" AS "l_receiptdate" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_receiptdate" > "lineitem"."l_commitdate" +) AS "l1" + ON "supplier"."s_suppkey" = "l1"."l_suppkey" +LEFT JOIN ( + SELECT + "l2"."l_orderkey" AS "l_orderkey", + ARRAY_AGG("l2"."l_suppkey") AS "_u_1" + FROM "lineitem" AS "l2" + GROUP BY + "l2"."l_orderkey" +) AS "_u_0" + ON "_u_0"."l_orderkey" = "l1"."l_orderkey" +LEFT JOIN ( + SELECT + "l3"."l_orderkey" AS "l_orderkey", + ARRAY_AGG("l3"."l_suppkey") AS "_u_3" + FROM "lineitem" AS "l3" + WHERE + "l3"."l_receiptdate" > "l3"."l_commitdate" + GROUP BY + "l3"."l_orderkey" +) AS "_u_2" + ON "_u_2"."l_orderkey" = "l1"."l_orderkey" +JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_orderstatus" AS "o_orderstatus" + FROM "orders" AS "orders" + WHERE + "orders"."o_orderstatus" = 'F' +) AS "orders" + ON "orders"."o_orderkey" = "l1"."l_orderkey" +JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name" + FROM "nation" AS "nation" + WHERE + "nation"."n_name" = 'SAUDI ARABIA' +) AS "nation" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" +WHERE + ( + "_u_2"."l_orderkey" IS NULL + OR NOT ARRAY_ANY("_u_2"."_u_3", "_x" -> "_x" <> "l1"."l_suppkey") + ) + AND ARRAY_ANY("_u_0"."_u_1", "_x" -> "_x" <> "l1"."l_suppkey") + AND NOT "_u_0"."l_orderkey" IS NULL +GROUP BY + "supplier"."s_name" +ORDER BY + "numwait" DESC, + "s_name" +LIMIT 100; + +-------------------------------------- +-- TPC-H 22 +-------------------------------------- +select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal +from + ( + select + substring(c_phone, 1, 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone, 1, 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone, 1, 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale +group by + cntrycode +order by + cntrycode; +SELECT + "custsale"."cntrycode" AS "cntrycode", + COUNT(*) AS "numcust", + SUM("custsale"."c_acctbal") AS "totacctbal" +FROM ( + SELECT + SUBSTRING("customer"."c_phone", 1, 2) AS "cntrycode", + "customer"."c_acctbal" AS "c_acctbal" + FROM "customer" AS "customer" + LEFT JOIN ( + SELECT 
+ "orders"."o_custkey" AS "_u_1" + FROM "orders" AS "orders" + GROUP BY + "orders"."o_custkey" + ) AS "_u_0" + ON "_u_0"."_u_1" = "customer"."c_custkey" + WHERE + "_u_0"."_u_1" IS NULL + AND "customer"."c_acctbal" > ( + SELECT + AVG("customer"."c_acctbal") AS "_col_0" + FROM "customer" AS "customer" + WHERE + "customer"."c_acctbal" > 0.00 + AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') + ) + AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') +) AS "custsale" +GROUP BY + "custsale"."cntrycode" +ORDER BY + "cntrycode"; diff --git a/tests/fixtures/optimizer/unnest_subqueries.sql b/tests/fixtures/optimizer/unnest_subqueries.sql new file mode 100644 index 0000000..9c4bd27 --- /dev/null +++ b/tests/fixtures/optimizer/unnest_subqueries.sql @@ -0,0 +1,206 @@ +-------------------------------------- +-- Unnest Subqueries +-------------------------------------- +SELECT * +FROM x AS x +WHERE + x.a IN (SELECT y.a AS a FROM y) + AND x.a IN (SELECT y.b AS b FROM y) + AND x.a = ANY (SELECT y.a AS a FROM y) + AND x.a = (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a) + AND x.a > (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a) + AND x.a <> ANY (SELECT y.a AS a FROM y WHERE y.a = x.a) + AND x.a NOT IN (SELECT y.a AS a FROM y WHERE y.a = x.a) + AND x.a IN (SELECT y.a AS a FROM y WHERE y.b = x.a) + AND x.a < (SELECT SUM(y.a) AS a FROM y WHERE y.a = x.a and y.a = x.b and y.b <> x.d) + AND EXISTS (SELECT y.a AS a, y.b AS b FROM y WHERE x.a = y.a) + AND x.a IN (SELECT y.a AS a FROM y LIMIT 10) + AND x.a IN (SELECT y.a AS a FROM y OFFSET 10) + AND x.a IN (SELECT y.a AS a, y.b AS b FROM y) + AND x.a > ANY (SELECT y.a FROM y) + AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10) + AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10) +; +SELECT + * +FROM x AS x +LEFT JOIN ( + SELECT + y.a AS a + FROM y + GROUP BY + y.a +) AS "_u_0" + ON x.a = "_u_0"."a" +LEFT JOIN ( + SELECT + y.b AS b + FROM y + GROUP BY + y.b +) AS "_u_1" + ON x.a = "_u_1"."b" +LEFT JOIN ( + SELECT + y.a AS a + FROM y + GROUP BY + y.a +) AS "_u_2" + ON x.a = "_u_2"."a" +LEFT JOIN ( + SELECT + SUM(y.b) AS b, + y.a AS _u_4 + FROM y + WHERE + TRUE + GROUP BY + y.a +) AS "_u_3" + ON x.a = "_u_3"."_u_4" +LEFT JOIN ( + SELECT + SUM(y.b) AS b, + y.a AS _u_6 + FROM y + WHERE + TRUE + GROUP BY + y.a +) AS "_u_5" + ON x.a = "_u_5"."_u_6" +LEFT JOIN ( + SELECT + y.a AS a + FROM y + WHERE + TRUE + GROUP BY + y.a +) AS "_u_7" + ON "_u_7".a = x.a +LEFT JOIN ( + SELECT + y.a AS a + FROM y + WHERE + TRUE + GROUP BY + y.a +) AS "_u_8" + ON "_u_8".a = x.a +LEFT JOIN ( + SELECT + ARRAY_AGG(y.a) AS a, + y.b AS _u_10 + FROM y + WHERE + TRUE + GROUP BY + y.b +) AS "_u_9" + ON "_u_9"."_u_10" = x.a +LEFT JOIN ( + SELECT + SUM(y.a) AS a, + y.a AS _u_12, + ARRAY_AGG(y.b) AS _u_13 + FROM y + WHERE + TRUE + AND TRUE + AND TRUE + GROUP BY + y.a +) AS "_u_11" + ON "_u_11"."_u_12" = x.a + AND "_u_11"."_u_12" = x.b +LEFT JOIN ( + SELECT + y.a AS a + FROM y + WHERE + TRUE + GROUP BY + y.a +) AS "_u_14" + ON x.a = "_u_14".a +WHERE + NOT "_u_0"."a" IS NULL + AND NOT "_u_1"."b" IS NULL + AND NOT "_u_2"."a" IS NULL + AND ( + x.a = "_u_3".b + AND NOT "_u_3"."_u_4" IS NULL + ) + AND ( + x.a > "_u_5".b + AND NOT "_u_5"."_u_6" IS NULL + ) + AND ( + None = "_u_7".a + AND NOT "_u_7".a IS NULL + ) + AND NOT ( + x.a = "_u_8".a + AND NOT "_u_8".a IS NULL + ) + AND ( + ARRAY_ANY("_u_9".a, _x -> _x = x.a) + AND NOT "_u_9"."_u_10" IS NULL + ) + AND ( + ( + ( + x.a < "_u_11".a + AND NOT 
"_u_11"."_u_12" IS NULL + ) + AND NOT "_u_11"."_u_12" IS NULL + ) + AND ARRAY_ANY("_u_11"."_u_13", "_x" -> "_x" <> x.d) + ) + AND ( + NOT "_u_14".a IS NULL + AND NOT "_u_14".a IS NULL + ) + AND x.a IN ( + SELECT + y.a AS a + FROM y + LIMIT 10 + ) + AND x.a IN ( + SELECT + y.a AS a + FROM y + OFFSET 10 + ) + AND x.a IN ( + SELECT + y.a AS a, + y.b AS b + FROM y + ) + AND x.a > ANY ( + SELECT + y.a + FROM y + ) + AND x.a = ( + SELECT + SUM(y.c) AS c + FROM y + WHERE + y.a = x.a + LIMIT 10 + ) + AND x.a = ( + SELECT + SUM(y.c) AS c + FROM y + WHERE + y.a = x.a + OFFSET 10 + ); + diff --git a/tests/fixtures/partial.sql b/tests/fixtures/partial.sql new file mode 100644 index 0000000..c6be364 --- /dev/null +++ b/tests/fixtures/partial.sql @@ -0,0 +1,8 @@ +SELECT a FROM +SELECT a FROM x WHERE +SELECT a + +a * +SELECT a FROM x JOIN +SELECT a FROM x GROUP BY +WITH a AS (SELECT 1), b AS (SELECT 2) +SELECT FROM x diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql new file mode 100644 index 0000000..5ed74f4 --- /dev/null +++ b/tests/fixtures/pretty.sql @@ -0,0 +1,285 @@ +SELECT * FROM test; +SELECT + * +FROM test; + +WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 2 AS b)) SELECT * FROM a; +WITH a AS ( + ( + SELECT + 1 AS b + ) + UNION ALL + ( + SELECT + 2 AS b + ) +) +SELECT + * +FROM a; + +WITH cte1 AS ( + SELECT a, z and e AS b + FROM cte + WHERE x IN (1, 2, 3) AND z < -1 OR z > 1 AND w = 'AND' +), cte2 AS ( + SELECT RANK() OVER (PARTITION BY a, b ORDER BY x DESC) a, b + FROM cte + CROSS JOIN ( + SELECT 1 + UNION ALL + SELECT 2 + UNION ALL + SELECT CASE x AND 1 + 1 = 2 + WHEN TRUE THEN 1 AND 4 + 3 AND Z + WHEN x and y THEN 2 + ELSE 3 AND 4 AND g END + UNION ALL + SELECT 1 + FROM (SELECT 1) AS x, y, (SELECT 2) z + UNION ALL + SELECT MAX(COALESCE(x AND y, a and b and c, d and e)), FOO(CASE WHEN a and b THEN c and d ELSE 3 END) + GROUP BY x, GROUPING SETS (a, (b, c)) CUBE(y, z) + ) x +) +SELECT a, b c FROM ( + SELECT a w, 1 + 1 AS c + FROM foo + WHERE w IN (SELECT z FROM q) + GROUP BY a, b +) x +LEFT JOIN ( + SELECT a, b + FROM (SELECT * FROM bar WHERE (c > 1 AND d > 1) OR e > 1 GROUP BY a HAVING a > 1 LIMIT 10) z +) y ON x.a = y.b AND x.a > 1 OR (x.c = y.d OR x.c = y.e); +WITH cte1 AS ( + SELECT + a, + z + AND e AS b + FROM cte + WHERE + x IN (1, 2, 3) + AND z < -1 + OR z > 1 + AND w = 'AND' +), cte2 AS ( + SELECT + RANK() OVER (PARTITION BY a, b ORDER BY x DESC) AS a, + b + FROM cte + CROSS JOIN ( + SELECT + 1 + UNION ALL + SELECT + 2 + UNION ALL + SELECT + CASE x + AND 1 + 1 = 2 + WHEN TRUE + THEN 1 + AND 4 + 3 + AND Z + WHEN x + AND y + THEN 2 + ELSE 3 + AND 4 + AND g + END + UNION ALL + SELECT + 1 + FROM ( + SELECT + 1 + ) AS x, y, ( + SELECT + 2 + ) AS z + UNION ALL + SELECT + MAX(COALESCE(x + AND y, a + AND b + AND c, d + AND e)), + FOO(CASE + WHEN a + AND b + THEN c + AND d + ELSE 3 + END) + GROUP BY + x + GROUPING SETS ( + a, + (b, c) + ) + CUBE ( + y, + z + ) + ) AS x +) +SELECT + a, + b AS c +FROM ( + SELECT + a AS w, + 1 + 1 AS c + FROM foo + WHERE + w IN ( + SELECT + z + FROM q + ) + GROUP BY + a, + b +) AS x +LEFT JOIN ( + SELECT + a, + b + FROM ( + SELECT + * + FROM bar + WHERE + ( + c > 1 + AND d > 1 + ) + OR e > 1 + GROUP BY + a + HAVING + a > 1 + LIMIT 10 + ) AS z +) AS y + ON x.a = y.b + AND x.a > 1 + OR ( + x.c = y.d + OR x.c = y.e + ); + +SELECT myCol1, myCol2 FROM baseTable LATERAL VIEW OUTER explode(col1) myTable1 AS myCol1 LATERAL VIEW explode(col2) myTable2 AS myCol2 +where a > 1 and b > 2 or c > 3; + +SELECT + myCol1, + myCol2 +FROM baseTable +LATERAL VIEW OUTER 
+EXPLODE(col1) myTable1 AS myCol1 +LATERAL VIEW +EXPLODE(col2) myTable2 AS myCol2 +WHERE + a > 1 + AND b > 2 + OR c > 3; + +SELECT * FROM (WITH y AS ( SELECT 1 AS z) SELECT z from y) x; +SELECT + * +FROM ( + WITH y AS ( + SELECT + 1 AS z + ) + SELECT + z + FROM y +) AS x; + +INSERT OVERWRITE TABLE x VALUES (1, 2.0, '3.0'), (4, 5.0, '6.0'); +INSERT OVERWRITE TABLE x VALUES + (1, 2.0, '3.0'), + (4, 5.0, '6.0'); + +WITH regional_sales AS ( + SELECT region, SUM(amount) AS total_sales + FROM orders + GROUP BY region + ), top_regions AS ( + SELECT region + FROM regional_sales + WHERE total_sales > (SELECT SUM(total_sales)/10 FROM regional_sales) +) +SELECT region, +product, +SUM(quantity) AS product_units, +SUM(amount) AS product_sales +FROM orders +WHERE region IN (SELECT region FROM top_regions) +GROUP BY region, product; +WITH regional_sales AS ( + SELECT + region, + SUM(amount) AS total_sales + FROM orders + GROUP BY + region +), top_regions AS ( + SELECT + region + FROM regional_sales + WHERE + total_sales > ( + SELECT + SUM(total_sales) / 10 + FROM regional_sales + ) +) +SELECT + region, + product, + SUM(quantity) AS product_units, + SUM(amount) AS product_sales +FROM orders +WHERE + region IN ( + SELECT + region + FROM top_regions + ) +GROUP BY + region, + product; + +CREATE TABLE "t_customer_account" ( "id" int, "customer_id" int, "bank" varchar(100), "account_no" varchar(100)); +CREATE TABLE "t_customer_account" ( + "id" INT, + "customer_id" INT, + "bank" VARCHAR(100), + "account_no" VARCHAR(100) +); + +CREATE TABLE "t_customer_account" ( + "id" int(11) NOT NULL AUTO_INCREMENT, + "customer_id" int(11) DEFAULT NULL COMMENT '客户id', + "bank" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别', + "account_no" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号', + PRIMARY KEY ("id") +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='客户账户表'; +CREATE TABLE "t_customer_account" ( + "id" INT(11) NOT NULL AUTO_INCREMENT, + "customer_id" INT(11) DEFAULT NULL COMMENT '客户id', + "bank" VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别', + "account_no" VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号', + PRIMARY KEY("id") +) +ENGINE=InnoDB +AUTO_INCREMENT=1 +DEFAULT CHARACTER SET=utf8 +COLLATE=utf8_bin +COMMENT='客户账户表'; diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..d4edb14 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,130 @@ +import os + +FILE_DIR = os.path.dirname(__file__) +FIXTURES_DIR = os.path.join(FILE_DIR, "fixtures") + + +def _filter_comments(s): + return "\n".join( + [line for line in s.splitlines() if line and not line.startswith("--")] + ) + + +def _extract_meta(sql): + meta = {} + sql_lines = sql.split("\n") + i = 0 + while sql_lines[i].startswith("#"): + key, val = sql_lines[i].split(":", maxsplit=1) + meta[key.lstrip("#").strip()] = val.strip() + i += 1 + sql = "\n".join(sql_lines[i:]) + return sql, meta + + +def assert_logger_contains(message, logger, level="error"): + output = "\n".join( + str(args[0][0]) for args in getattr(logger, level).call_args_list + ) + assert message in output + + +def load_sql_fixtures(filename): + with open(os.path.join(FIXTURES_DIR, filename), encoding="utf-8") as f: + for sql in _filter_comments(f.read()).splitlines(): + yield sql + + +def load_sql_fixture_pairs(filename): + with open(os.path.join(FIXTURES_DIR, filename), encoding="utf-8") as f: + statements = _filter_comments(f.read()).split(";") + + size = len(statements) + + for i in range(0, size, 2): + if i 
+ 1 < size: + sql = statements[i].strip() + sql, meta = _extract_meta(sql) + expected = statements[i + 1].strip() + yield meta, sql, expected + + +TPCH_SCHEMA = { + "lineitem": { + "l_orderkey": "uint64", + "l_partkey": "uint64", + "l_suppkey": "uint64", + "l_linenumber": "uint64", + "l_quantity": "float64", + "l_extendedprice": "float64", + "l_discount": "float64", + "l_tax": "float64", + "l_returnflag": "string", + "l_linestatus": "string", + "l_shipdate": "date32", + "l_commitdate": "date32", + "l_receiptdate": "date32", + "l_shipinstruct": "string", + "l_shipmode": "string", + "l_comment": "string", + }, + "orders": { + "o_orderkey": "uint64", + "o_custkey": "uint64", + "o_orderstatus": "string", + "o_totalprice": "float64", + "o_orderdate": "date32", + "o_orderpriority": "string", + "o_clerk": "string", + "o_shippriority": "int32", + "o_comment": "string", + }, + "customer": { + "c_custkey": "uint64", + "c_name": "string", + "c_address": "string", + "c_nationkey": "uint64", + "c_phone": "string", + "c_acctbal": "float64", + "c_mktsegment": "string", + "c_comment": "string", + }, + "part": { + "p_partkey": "uint64", + "p_name": "string", + "p_mfgr": "string", + "p_brand": "string", + "p_type": "string", + "p_size": "int32", + "p_container": "string", + "p_retailprice": "float64", + "p_comment": "string", + }, + "supplier": { + "s_suppkey": "uint64", + "s_name": "string", + "s_address": "string", + "s_nationkey": "uint64", + "s_phone": "string", + "s_acctbal": "float64", + "s_comment": "string", + }, + "partsupp": { + "ps_partkey": "uint64", + "ps_suppkey": "uint64", + "ps_availqty": "int32", + "ps_supplycost": "float64", + "ps_comment": "string", + }, + "nation": { + "n_nationkey": "uint64", + "n_name": "string", + "n_regionkey": "uint64", + "n_comment": "string", + }, + "region": { + "r_regionkey": "uint64", + "r_name": "string", + "r_comment": "string", + }, +} diff --git a/tests/test_build.py b/tests/test_build.py new file mode 100644 index 0000000..a4cffde --- /dev/null +++ b/tests/test_build.py @@ -0,0 +1,384 @@ +import unittest + +from sqlglot import and_, condition, exp, from_, not_, or_, parse_one, select + + +class TestBuild(unittest.TestCase): + def test_build(self): + for expression, sql, *dialect in [ + (lambda: select("x"), "SELECT x"), + (lambda: select("x", "y"), "SELECT x, y"), + (lambda: select("x").from_("tbl"), "SELECT x FROM tbl"), + (lambda: select("x", "y").from_("tbl"), "SELECT x, y FROM tbl"), + (lambda: select("x").select("y").from_("tbl"), "SELECT x, y FROM tbl"), + ( + lambda: select("x").select("y", append=False).from_("tbl"), + "SELECT y FROM tbl", + ), + (lambda: select("x").from_("tbl").from_("tbl2"), "SELECT x FROM tbl, tbl2"), + ( + lambda: select("x").from_("tbl, tbl2", "tbl3").from_("tbl4"), + "SELECT x FROM tbl, tbl2, tbl3, tbl4", + ), + ( + lambda: select("x").from_("tbl").from_("tbl2", append=False), + "SELECT x FROM tbl2", + ), + (lambda: select("SUM(x) AS y"), "SELECT SUM(x) AS y"), + ( + lambda: select("x").from_("tbl").where("x > 0"), + "SELECT x FROM tbl WHERE x > 0", + ), + ( + lambda: select("x").from_("tbl").where("x < 4 OR x > 5"), + "SELECT x FROM tbl WHERE x < 4 OR x > 5", + ), + ( + lambda: select("x").from_("tbl").where("x > 0").where("x < 9"), + "SELECT x FROM tbl WHERE x > 0 AND x < 9", + ), + ( + lambda: select("x").from_("tbl").where("x > 0", "x < 9"), + "SELECT x FROM tbl WHERE x > 0 AND x < 9", + ), + ( + lambda: select("x").from_("tbl").where(None).where(False, ""), + "SELECT x FROM tbl WHERE FALSE", + ), + ( + lambda: 
select("x") + .from_("tbl") + .where("x > 0") + .where("x < 9", append=False), + "SELECT x FROM tbl WHERE x < 9", + ), + ( + lambda: select("x", "y").from_("tbl").group_by("x"), + "SELECT x, y FROM tbl GROUP BY x", + ), + ( + lambda: select("x", "y").from_("tbl").group_by("x, y"), + "SELECT x, y FROM tbl GROUP BY x, y", + ), + ( + lambda: select("x", "y", "z", "a") + .from_("tbl") + .group_by("x, y", "z") + .group_by("a"), + "SELECT x, y, z, a FROM tbl GROUP BY x, y, z, a", + ), + ( + lambda: select("x").distinct(True).from_("tbl"), + "SELECT DISTINCT x FROM tbl", + ), + (lambda: select("x").distinct(False).from_("tbl"), "SELECT x FROM tbl"), + ( + lambda: select("x").lateral("OUTER explode(y) tbl2 AS z").from_("tbl"), + "SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z", + ), + ( + lambda: select("x").from_("tbl").join("tbl2 ON tbl.y = tbl2.y"), + "SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y", + ), + ( + lambda: select("x").from_("tbl").join("tbl2", on="tbl.y = tbl2.y"), + "SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y", + ), + ( + lambda: select("x") + .from_("tbl") + .join("tbl2", on=["tbl.y = tbl2.y", "a = b"]), + "SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y AND a = b", + ), + ( + lambda: select("x").from_("tbl").join("tbl2", join_type="left outer"), + "SELECT x FROM tbl LEFT OUTER JOIN tbl2", + ), + ( + lambda: select("x") + .from_("tbl") + .join(exp.Table(this="tbl2"), join_type="left outer"), + "SELECT x FROM tbl LEFT OUTER JOIN tbl2", + ), + ( + lambda: select("x") + .from_("tbl") + .join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"), + "SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo", + ), + ( + lambda: select("x") + .from_("tbl") + .join(select("y").from_("tbl2"), join_type="left outer"), + "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)", + ), + ( + lambda: select("x") + .from_("tbl") + .join( + select("y").from_("tbl2").subquery("aliased"), + join_type="left outer", + ), + "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased", + ), + ( + lambda: select("x") + .from_("tbl") + .join( + select("y").from_("tbl2"), + join_type="left outer", + join_alias="aliased", + ), + "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased", + ), + ( + lambda: select("x") + .from_("tbl") + .join(parse_one("left join x", into=exp.Join), on="a=b"), + "SELECT x FROM tbl LEFT JOIN x ON a = b", + ), + ( + lambda: select("x").from_("tbl").join("left join x", on="a=b"), + "SELECT x FROM tbl LEFT JOIN x ON a = b", + ), + ( + lambda: select("x") + .from_("tbl") + .join("select b from tbl2", on="a=b", join_type="left"), + "SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b", + ), + ( + lambda: select("x") + .from_("tbl") + .join( + "select b from tbl2", + on="a=b", + join_type="left", + join_alias="aliased", + ), + "SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) AS aliased ON a = b", + ), + ( + lambda: select("x", "COUNT(y)") + .from_("tbl") + .group_by("x") + .having("COUNT(y) > 0"), + "SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 0", + ), + ( + lambda: select("x").from_("tbl").order_by("y"), + "SELECT x FROM tbl ORDER BY y", + ), + ( + lambda: select("x").from_("tbl").cluster_by("y"), + "SELECT x FROM tbl CLUSTER BY y", + ), + ( + lambda: select("x").from_("tbl").sort_by("y"), + "SELECT x FROM tbl SORT BY y", + ), + ( + lambda: select("x").from_("tbl").order_by("x, y DESC"), + "SELECT x FROM tbl ORDER BY x, y DESC", + ), + ( + lambda: select("x").from_("tbl").cluster_by("x, y DESC"), + "SELECT x FROM tbl CLUSTER BY x, y 
DESC", + ), + ( + lambda: select("x").from_("tbl").sort_by("x, y DESC"), + "SELECT x FROM tbl SORT BY x, y DESC", + ), + ( + lambda: select("x", "y", "z", "a") + .from_("tbl") + .order_by("x, y", "z") + .order_by("a"), + "SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a", + ), + ( + lambda: select("x", "y", "z", "a") + .from_("tbl") + .cluster_by("x, y", "z") + .cluster_by("a"), + "SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a", + ), + ( + lambda: select("x", "y", "z", "a") + .from_("tbl") + .sort_by("x, y", "z") + .sort_by("a"), + "SELECT x, y, z, a FROM tbl SORT BY x, y, z, a", + ), + (lambda: select("x").from_("tbl").limit(10), "SELECT x FROM tbl LIMIT 10"), + ( + lambda: select("x").from_("tbl").offset(10), + "SELECT x FROM tbl OFFSET 10", + ), + ( + lambda: select("x").from_("tbl").with_("tbl", as_="SELECT x FROM tbl2"), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", + ), + ( + lambda: select("x") + .from_("tbl") + .with_("tbl", as_="SELECT x FROM tbl2", recursive=True), + "WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", + ), + ( + lambda: select("x") + .from_("tbl") + .with_("tbl", as_=select("x").from_("tbl2")), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", + ), + ( + lambda: select("x") + .from_("tbl") + .with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")), + "WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl", + ), + ( + lambda: select("x") + .from_("tbl") + .with_("tbl", as_=select("x").from_("tbl2")) + .with_("tbl2", as_=select("x").from_("tbl3")), + "WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl", + ), + ( + lambda: select("x") + .from_("tbl") + .with_("tbl", as_=select("x", "y").from_("tbl2")) + .select("y"), + "WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl"), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .group_by("x"), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .order_by("x"), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .limit(10), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .offset(10), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .join("tbl3"), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .distinct(), + "WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .where("x > 10"), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .having("x > 20"), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20", + ), + (lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"), + ( + lambda: select("x").from_("tbl").subquery("y"), + "(SELECT x FROM tbl) AS y", + ), + ( + lambda: 
select("x").from_(select("x").from_("tbl").subquery()), + "SELECT x FROM (SELECT x FROM tbl)", + ), + (lambda: from_("tbl").select("x"), "SELECT x FROM tbl"), + ( + lambda: parse_one("SELECT a FROM tbl") + .assert_is(exp.Select) + .select("b"), + "SELECT a, b FROM tbl", + ), + ( + lambda: parse_one("SELECT * FROM y").assert_is(exp.Select).ctas("x"), + "CREATE TABLE x AS SELECT * FROM y", + ), + ( + lambda: parse_one("SELECT * FROM y") + .assert_is(exp.Select) + .ctas("foo.x", properties={"format": "parquet", "y": "2"}), + "CREATE TABLE foo.x STORED AS PARQUET TBLPROPERTIES ('y' = '2') AS SELECT * FROM y", + "hive", + ), + (lambda: and_("x=1", "y=1"), "x = 1 AND y = 1"), + (lambda: condition("x").and_("y['a']").and_("1"), "(x AND y['a']) AND 1"), + (lambda: condition("x=1").and_("y=1"), "x = 1 AND y = 1"), + (lambda: and_("x=1", "y=1", "z=1"), "x = 1 AND y = 1 AND z = 1"), + (lambda: condition("x=1").and_("y=1", "z=1"), "x = 1 AND y = 1 AND z = 1"), + (lambda: and_("x=1", and_("y=1", "z=1")), "x = 1 AND (y = 1 AND z = 1)"), + ( + lambda: condition("x=1").and_("y=1").and_("z=1"), + "(x = 1 AND y = 1) AND z = 1", + ), + (lambda: or_(and_("x=1", "y=1"), "z=1"), "(x = 1 AND y = 1) OR z = 1"), + ( + lambda: condition("x=1").and_("y=1").or_("z=1"), + "(x = 1 AND y = 1) OR z = 1", + ), + (lambda: or_("z=1", and_("x=1", "y=1")), "z = 1 OR (x = 1 AND y = 1)"), + ( + lambda: or_("z=1 OR a=1", and_("x=1", "y=1")), + "(z = 1 OR a = 1) OR (x = 1 AND y = 1)", + ), + (lambda: not_("x=1"), "NOT x = 1"), + (lambda: condition("x=1").not_(), "NOT x = 1"), + (lambda: condition("x=1").and_("y=1").not_(), "NOT (x = 1 AND y = 1)"), + ( + lambda: select("*").from_("x").where(condition("y=1").and_("z=1")), + "SELECT * FROM x WHERE y = 1 AND z = 1", + ), + ( + lambda: exp.subquery("select x from tbl", "foo") + .select("x") + .where("x > 0"), + "SELECT x FROM (SELECT x FROM tbl) AS foo WHERE x > 0", + ), + ( + lambda: exp.subquery( + "select x from tbl UNION select x from bar", "unioned" + ).select("x"), + "SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned", + ), + ]: + with self.subTest(sql): + self.assertEqual(expression().sql(dialect[0] if dialect else None), sql) diff --git a/tests/test_diff.py b/tests/test_diff.py new file mode 100644 index 0000000..cbd53b3 --- /dev/null +++ b/tests/test_diff.py @@ -0,0 +1,137 @@ +import unittest + +from sqlglot import parse_one +from sqlglot.diff import Insert, Keep, Move, Remove, Update, diff +from sqlglot.expressions import Join, to_identifier + + +class TestDiff(unittest.TestCase): + def test_simple(self): + self._validate_delta_only( + diff(parse_one("SELECT a + b"), parse_one("SELECT a - b")), + [ + Remove(parse_one("a + b")), # the Add node + Insert(parse_one("a - b")), # the Sub node + ], + ) + + self._validate_delta_only( + diff(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")), + [ + Remove(to_identifier("b", quoted=False)), # the Identifier node + Remove(parse_one("b")), # the Column node + ], + ) + + self._validate_delta_only( + diff(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")), + [ + Insert(to_identifier("c", quoted=False)), # the Identifier node + Insert(parse_one("c")), # the Column node + ], + ) + + self._validate_delta_only( + diff( + parse_one("SELECT a FROM table_one"), + parse_one("SELECT a FROM table_two"), + ), + [ + Update( + to_identifier("table_one", quoted=False), + to_identifier("table_two", quoted=False), + ), # the Identifier node + ], + ) + + def test_node_position_changed(self): + self._validate_delta_only( + 
diff(parse_one("SELECT a, b, c"), parse_one("SELECT c, a, b")), + [ + Move(parse_one("c")), # the Column node + ], + ) + + self._validate_delta_only( + diff(parse_one("SELECT a + b"), parse_one("SELECT b + a")), + [ + Move(parse_one("a")), # the Column node + ], + ) + + self._validate_delta_only( + diff(parse_one("SELECT aaaa AND bbbb"), parse_one("SELECT bbbb AND aaaa")), + [ + Move(parse_one("aaaa")), # the Column node + ], + ) + + self._validate_delta_only( + diff( + parse_one("SELECT aaaa OR bbbb OR cccc"), + parse_one("SELECT cccc OR bbbb OR aaaa"), + ), + [ + Move(parse_one("aaaa")), # the Column node + Move(parse_one("cccc")), # the Column node + ], + ) + + def test_cte(self): + expr_src = """ + WITH + cte1 AS (SELECT a, b, LOWER(c) AS c FROM table_one WHERE d = 'filter'), + cte2 AS (SELECT d, e, f FROM table_two) + SELECT a, b, d, e FROM cte1 JOIN cte2 ON f = c + """ + expr_tgt = """ + WITH + cte1 AS (SELECT a, b, c FROM table_one WHERE d = 'different_filter'), + cte2 AS (SELECT d, e, f FROM table_two) + SELECT a, b, d, e FROM cte1 JOIN cte2 ON f = c + """ + + self._validate_delta_only( + diff(parse_one(expr_src), parse_one(expr_tgt)), + [ + Remove(parse_one("LOWER(c) AS c")), # the Alias node + Remove(to_identifier("c", quoted=False)), # the Identifier node + Remove(parse_one("LOWER(c)")), # the Lower node + Remove(parse_one("'filter'")), # the Literal node + Insert(parse_one("'different_filter'")), # the Literal node + ], + ) + + def test_join(self): + expr_src = "SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key" + expr_tgt = "SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key" + + changes = diff(parse_one(expr_src), parse_one(expr_tgt)) + changes = _delta_only(changes) + + self.assertEqual(len(changes), 2) + self.assertTrue(isinstance(changes[0], Remove)) + self.assertTrue(isinstance(changes[1], Insert)) + self.assertTrue(all(isinstance(c.expression, Join) for c in changes)) + + def test_window_functions(self): + expr_src = parse_one("SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b)") + expr_tgt = parse_one("SELECT RANK() OVER (PARTITION BY a ORDER BY b)") + + self._validate_delta_only(diff(expr_src, expr_src), []) + + self._validate_delta_only( + diff(expr_src, expr_tgt), + [ + Remove(parse_one("ROW_NUMBER()")), # the Anonymous node + Insert(parse_one("RANK()")), # the Anonymous node + ], + ) + + def _validate_delta_only(self, actual_diff, expected_delta): + actual_delta = _delta_only(actual_diff) + self.assertEqual(set(actual_delta), set(expected_delta)) + + +def _delta_only(changes): + return [d for d in changes if not isinstance(d, Keep)] diff --git a/tests/test_docs.py b/tests/test_docs.py new file mode 100644 index 0000000..95aa814 --- /dev/null +++ b/tests/test_docs.py @@ -0,0 +1,30 @@ +import doctest +import inspect +import unittest + +import sqlglot +import sqlglot.optimizer +import sqlglot.transforms + + +def load_tests(loader, tests, ignore): + """ + This finds and runs all the doctests + """ + + modules = { + mod + for module in [sqlglot, sqlglot.transforms, sqlglot.optimizer] + for _, mod in inspect.getmembers(module, inspect.ismodule) + } + + assert len(modules) >= 20 + + for module in modules: + tests.addTests(doctest.DocTestSuite(module)) + + return tests + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_executor.py b/tests/test_executor.py new file mode 100644 index 0000000..9afa225 --- /dev/null +++ b/tests/test_executor.py @@ -0,0 +1,72 @@ +import unittest + +import duckdb +import pandas as pd +from pandas.testing import 
assert_frame_equal + +from sqlglot import exp, parse_one +from sqlglot.executor import execute +from sqlglot.executor.python import Python +from tests.helpers import FIXTURES_DIR, TPCH_SCHEMA, load_sql_fixture_pairs + +DIR = FIXTURES_DIR + "/optimizer/tpc-h/" + + +class TestExecutor(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = duckdb.connect() + + for table in TPCH_SCHEMA: + cls.conn.execute( + f""" + CREATE VIEW {table} AS + SELECT * + FROM READ_CSV_AUTO('{DIR}{table}.csv.gz') + """ + ) + + cls.cache = {} + cls.sqls = [ + (sql, expected) + for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql") + ] + + @classmethod + def tearDownClass(cls): + cls.conn.close() + + def cached_execute(self, sql): + if sql not in self.cache: + self.cache[sql] = self.conn.execute(sql).fetchdf() + return self.cache[sql] + + def rename_anonymous(self, source, target): + for i, column in enumerate(source.columns): + if "_col_" in column: + source.rename(columns={column: target.columns[i]}, inplace=True) + + def test_py_dialect(self): + self.assertEqual(Python().generate(parse_one("'x '''")), r"'x \''") + + def test_optimized_tpch(self): + for sql, optimized in self.sqls[0:20]: + a = self.cached_execute(sql) + b = self.conn.execute(optimized).fetchdf() + self.rename_anonymous(b, a) + assert_frame_equal(a, b) + + def test_execute_tpch(self): + def to_csv(expression): + if isinstance(expression, exp.Table): + return parse_one( + f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}" + ) + return expression + + for sql, _ in self.sqls[0:3]: + a = self.cached_execute(sql) + sql = parse_one(sql).transform(to_csv).sql(pretty=True) + table = execute(sql, TPCH_SCHEMA) + b = pd.DataFrame(table.rows, columns=table.columns) + assert_frame_equal(a, b, check_dtype=False) diff --git a/tests/test_expressions.py b/tests/test_expressions.py new file mode 100644 index 0000000..eaef022 --- /dev/null +++ b/tests/test_expressions.py @@ -0,0 +1,415 @@ +import unittest + +from sqlglot import alias, exp, parse_one + + +class TestExpressions(unittest.TestCase): + def test_arg_key(self): + self.assertEqual(parse_one("sum(1)").find(exp.Literal).arg_key, "this") + + def test_depth(self): + self.assertEqual(parse_one("x(1)").find(exp.Literal).depth, 1) + + def test_eq(self): + self.assertEqual(parse_one("`a`", read="hive"), parse_one('"a"')) + self.assertEqual(parse_one("`a`", read="hive"), parse_one('"a" ')) + self.assertEqual(parse_one("`a`.b", read="hive"), parse_one('"a"."b"')) + self.assertEqual(parse_one("select a, b+1"), parse_one("SELECT a, b + 1")) + self.assertEqual(parse_one("`a`.`b`.`c`", read="hive"), parse_one("a.b.c")) + self.assertNotEqual(parse_one("a.b.c.d", read="hive"), parse_one("a.b.c")) + self.assertEqual(parse_one("a.b.c.d", read="hive"), parse_one("a.b.c.d")) + self.assertEqual(parse_one("a + b * c - 1.0"), parse_one("a+b*c-1.0")) + self.assertNotEqual(parse_one("a + b * c - 1.0"), parse_one("a + b * c + 1.0")) + self.assertEqual(parse_one("a as b"), parse_one("a AS b")) + self.assertNotEqual(parse_one("a as b"), parse_one("a")) + self.assertEqual( + parse_one("ROW() OVER(Partition by y)"), + parse_one("ROW() OVER (partition BY y)"), + ) + self.assertEqual( + parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)") + ) + + def test_find(self): + expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y") + self.assertTrue(expression.find(exp.Create)) + self.assertFalse(expression.find(exp.Group)) + 
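+ # find() returns only the first node of the requested type (breadth-first
+ # by default), while find_all() yields every match in traversal order;
+ # test_find_all below also covers the bfs=False depth-first ordering.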
self.assertEqual( + [table.name for table in expression.find_all(exp.Table)], + ["x", "y"], + ) + + def test_find_all(self): + expression = parse_one( + """ + SELECT * + FROM ( + SELECT b.* + FROM a.b b + ) x + JOIN ( + SELECT c.foo + FROM a.c c + WHERE foo = 1 + ) y + ON x.c = y.foo + CROSS JOIN ( + SELECT * + FROM ( + SELECT d.bar + FROM d + ) nested + ) z + ON x.c = y.foo + """ + ) + + self.assertEqual( + [table.name for table in expression.find_all(exp.Table)], + ["b", "c", "d"], + ) + + expression = parse_one("select a + b + c + d") + + self.assertEqual( + [column.name for column in expression.find_all(exp.Column)], + ["d", "c", "a", "b"], + ) + self.assertEqual( + [column.name for column in expression.find_all(exp.Column, bfs=False)], + ["a", "b", "c", "d"], + ) + + def test_find_ancestor(self): + column = parse_one("select * from foo where (a + 1 > 2)").find(exp.Column) + self.assertIsInstance(column, exp.Column) + self.assertIsInstance(column.parent_select, exp.Select) + self.assertIsNone(column.find_ancestor(exp.Join)) + + def test_alias_or_name(self): + expression = parse_one( + "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" + ) + self.assertEqual( + [e.alias_or_name for e in expression.expressions], + ["a", "B", "e", "*", "zz", "z"], + ) + self.assertEqual( + [e.alias_or_name for e in expression.args["from"].expressions], + ["bar", "baz"], + ) + + expression = parse_one( + """ + WITH first AS (SELECT * FROM foo), + second AS (SELECT * FROM bar) + SELECT * FROM first, second, (SELECT * FROM baz) AS third + """ + ) + + self.assertEqual( + [e.alias_or_name for e in expression.args["with"].expressions], + ["first", "second"], + ) + + self.assertEqual( + [e.alias_or_name for e in expression.args["from"].expressions], + ["first", "second", "third"], + ) + + def test_named_selects(self): + expression = parse_one( + "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" + ) + self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"]) + + expression = parse_one( + """ + WITH first AS (SELECT * FROM foo) + SELECT foo.bar, foo.baz as bazz, SUM(x) FROM first + """ + ) + self.assertEqual(expression.named_selects, ["bar", "bazz"]) + + expression = parse_one( + """ + SELECT foo, bar FROM first + UNION SELECT "ss" as foo, bar FROM second + UNION ALL SELECT foo, bazz FROM third + """ + ) + self.assertEqual(expression.named_selects, ["foo", "bar"]) + + def test_selects(self): + expression = parse_one("SELECT FROM x") + self.assertEqual(expression.selects, []) + + expression = parse_one("SELECT a FROM x") + self.assertEqual([s.sql() for s in expression.selects], ["a"]) + + expression = parse_one("SELECT a, b FROM x") + self.assertEqual([s.sql() for s in expression.selects], ["a", "b"]) + + def test_alias_column_names(self): + expression = parse_one("SELECT * FROM (SELECT * FROM x) AS y") + subquery = expression.find(exp.Subquery) + self.assertEqual(subquery.alias_column_names, []) + + expression = parse_one("SELECT * FROM (SELECT * FROM x) AS y(a)") + subquery = expression.find(exp.Subquery) + self.assertEqual(subquery.alias_column_names, ["a"]) + + expression = parse_one("SELECT * FROM (SELECT * FROM x) AS y(a, b)") + subquery = expression.find(exp.Subquery) + self.assertEqual(subquery.alias_column_names, ["a", "b"]) + + expression = parse_one("WITH y AS (SELECT * FROM x) SELECT * FROM y") + cte = expression.find(exp.CTE) + self.assertEqual(cte.alias_column_names, []) + + expression = parse_one("WITH y(a, b) AS (SELECT * FROM x) SELECT * 
FROM y") + cte = expression.find(exp.CTE) + self.assertEqual(cte.alias_column_names, ["a", "b"]) + + def test_ctes(self): + expression = parse_one("SELECT a FROM x") + self.assertEqual(expression.ctes, []) + + expression = parse_one("WITH x AS (SELECT a FROM y) SELECT a FROM x") + self.assertEqual([s.sql() for s in expression.ctes], ["x AS (SELECT a FROM y)"]) + + def test_hash(self): + self.assertEqual( + { + parse_one("select a.b"), + parse_one("1+2"), + parse_one('"a".b'), + parse_one("a.b.c.d"), + }, + { + parse_one("select a.b"), + parse_one("1+2"), + parse_one('"a"."b"'), + parse_one("a.b.c.d"), + }, + ) + + def test_sql(self): + self.assertEqual(parse_one("x + y * 2").sql(), "x + y * 2") + self.assertEqual( + parse_one('select "x"').sql(dialect="hive", pretty=True), "SELECT\n `x`" + ) + self.assertEqual( + parse_one("X + y").sql(identify=True, normalize=True), '"x" + "y"' + ) + self.assertEqual( + parse_one("SUM(X)").sql(identify=True, normalize=True), 'SUM("x")' + ) + + def test_transform_with_arguments(self): + expression = parse_one("a") + + def fun(node, alias_=True): + if alias_: + return parse_one("a AS a") + return node + + transformed_expression = expression.transform(fun) + self.assertEqual(transformed_expression.sql(dialect="presto"), "a AS a") + + transformed_expression_2 = expression.transform(fun, alias_=False) + self.assertEqual(transformed_expression_2.sql(dialect="presto"), "a") + + def test_transform_simple(self): + expression = parse_one("IF(a > 0, a, b)") + + def fun(node): + if isinstance(node, exp.Column) and node.name == "a": + return parse_one("c - 2") + return node + + actual_expression_1 = expression.transform(fun) + self.assertEqual( + actual_expression_1.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)" + ) + self.assertIsNot(actual_expression_1, expression) + + actual_expression_2 = expression.transform(fun, copy=False) + self.assertEqual( + actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)" + ) + self.assertIs(actual_expression_2, expression) + + with self.assertRaises(ValueError): + parse_one("a").transform(lambda n: None) + + def test_transform_no_infinite_recursion(self): + expression = parse_one("a") + + def fun(node): + if isinstance(node, exp.Column) and node.name == "a": + return parse_one("FUN(a)") + return node + + self.assertEqual(expression.transform(fun).sql(), "FUN(a)") + + def test_transform_multiple_children(self): + expression = parse_one("SELECT * FROM x") + + def fun(node): + if isinstance(node, exp.Star): + return [parse_one(c) for c in ["a", "b"]] + return node + + self.assertEqual(expression.transform(fun).sql(), "SELECT a, b FROM x") + + def test_replace(self): + expression = parse_one("SELECT a, b FROM x") + expression.find(exp.Column).replace(parse_one("c")) + self.assertEqual(expression.sql(), "SELECT c, b FROM x") + expression.find(exp.Table).replace(parse_one("y")) + self.assertEqual(expression.sql(), "SELECT c, b FROM y") + + def test_walk(self): + expression = parse_one("SELECT * FROM (SELECT * FROM x)") + self.assertEqual(len(list(expression.walk())), 9) + self.assertEqual(len(list(expression.walk(bfs=False))), 9) + self.assertTrue( + all(isinstance(e, exp.Expression) for e, _, _ in expression.walk()) + ) + self.assertTrue( + all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)) + ) + + def test_functions(self): + self.assertIsInstance(parse_one("ABS(a)"), exp.Abs) + self.assertIsInstance(parse_one("APPROX_DISTINCT(a)"), exp.ApproxDistinct) + 
self.assertIsInstance(parse_one("ARRAY(a)"), exp.Array) + self.assertIsInstance(parse_one("ARRAY_AGG(a)"), exp.ArrayAgg) + self.assertIsInstance(parse_one("ARRAY_CONTAINS(a, 'a')"), exp.ArrayContains) + self.assertIsInstance(parse_one("ARRAY_SIZE(a)"), exp.ArraySize) + self.assertIsInstance(parse_one("AVG(a)"), exp.Avg) + self.assertIsInstance(parse_one("CEIL(a)"), exp.Ceil) + self.assertIsInstance(parse_one("CEILING(a)"), exp.Ceil) + self.assertIsInstance(parse_one("COALESCE(a, b)"), exp.Coalesce) + self.assertIsInstance(parse_one("COUNT(a)"), exp.Count) + self.assertIsInstance(parse_one("DATE_ADD(a, 1)"), exp.DateAdd) + self.assertIsInstance(parse_one("DATE_DIFF(a, 2)"), exp.DateDiff) + self.assertIsInstance(parse_one("DATE_STR_TO_DATE(a)"), exp.DateStrToDate) + self.assertIsInstance(parse_one("DAY(a)"), exp.Day) + self.assertIsInstance(parse_one("EXP(a)"), exp.Exp) + self.assertIsInstance(parse_one("FLOOR(a)"), exp.Floor) + self.assertIsInstance(parse_one("GREATEST(a, b)"), exp.Greatest) + self.assertIsInstance(parse_one("IF(a, b, c)"), exp.If) + self.assertIsInstance(parse_one("INITCAP(a)"), exp.Initcap) + self.assertIsInstance(parse_one("JSON_EXTRACT(a, '$.name')"), exp.JSONExtract) + self.assertIsInstance( + parse_one("JSON_EXTRACT_SCALAR(a, '$.name')"), exp.JSONExtractScalar + ) + self.assertIsInstance(parse_one("LEAST(a, b)"), exp.Least) + self.assertIsInstance(parse_one("LN(a)"), exp.Ln) + self.assertIsInstance(parse_one("LOG10(a)"), exp.Log10) + self.assertIsInstance(parse_one("MAX(a)"), exp.Max) + self.assertIsInstance(parse_one("MIN(a)"), exp.Min) + self.assertIsInstance(parse_one("MONTH(a)"), exp.Month) + self.assertIsInstance(parse_one("POW(a, 2)"), exp.Pow) + self.assertIsInstance(parse_one("POWER(a, 2)"), exp.Pow) + self.assertIsInstance(parse_one("QUANTILE(a, 0.90)"), exp.Quantile) + self.assertIsInstance(parse_one("REGEXP_LIKE(a, 'test')"), exp.RegexpLike) + self.assertIsInstance(parse_one("REGEXP_SPLIT(a, 'test')"), exp.RegexpSplit) + self.assertIsInstance(parse_one("ROUND(a)"), exp.Round) + self.assertIsInstance(parse_one("ROUND(a, 2)"), exp.Round) + self.assertIsInstance(parse_one("SPLIT(a, 'test')"), exp.Split) + self.assertIsInstance(parse_one("STR_POSITION(a, 'test')"), exp.StrPosition) + self.assertIsInstance(parse_one("STR_TO_UNIX(a, 'format')"), exp.StrToUnix) + self.assertIsInstance(parse_one("STRUCT_EXTRACT(a, 'test')"), exp.StructExtract) + self.assertIsInstance(parse_one("SUM(a)"), exp.Sum) + self.assertIsInstance(parse_one("SQRT(a)"), exp.Sqrt) + self.assertIsInstance(parse_one("STDDEV(a)"), exp.Stddev) + self.assertIsInstance(parse_one("STDDEV_POP(a)"), exp.StddevPop) + self.assertIsInstance(parse_one("STDDEV_SAMP(a)"), exp.StddevSamp) + self.assertIsInstance(parse_one("TIME_TO_STR(a, 'format')"), exp.TimeToStr) + self.assertIsInstance(parse_one("TIME_TO_TIME_STR(a)"), exp.Cast) + self.assertIsInstance(parse_one("TIME_TO_UNIX(a)"), exp.TimeToUnix) + self.assertIsInstance(parse_one("TIME_STR_TO_DATE(a)"), exp.TimeStrToDate) + self.assertIsInstance(parse_one("TIME_STR_TO_TIME(a)"), exp.TimeStrToTime) + self.assertIsInstance(parse_one("TIME_STR_TO_UNIX(a)"), exp.TimeStrToUnix) + self.assertIsInstance(parse_one("TS_OR_DS_ADD(a, 1, 'day')"), exp.TsOrDsAdd) + self.assertIsInstance(parse_one("TS_OR_DS_TO_DATE(a)"), exp.TsOrDsToDate) + self.assertIsInstance(parse_one("TS_OR_DS_TO_DATE_STR(a)"), exp.Substring) + self.assertIsInstance(parse_one("UNIX_TO_STR(a, 'format')"), exp.UnixToStr) + self.assertIsInstance(parse_one("UNIX_TO_TIME(a)"), exp.UnixToTime) + 
self.assertIsInstance(parse_one("UNIX_TO_TIME_STR(a)"), exp.UnixToTimeStr) + self.assertIsInstance(parse_one("VARIANCE(a)"), exp.Variance) + self.assertIsInstance(parse_one("VARIANCE_POP(a)"), exp.VariancePop) + self.assertIsInstance(parse_one("YEAR(a)"), exp.Year) + + def test_column(self): + dot = parse_one("a.b.c") + column = dot.this + self.assertEqual(column.table, "a") + self.assertEqual(column.name, "b") + self.assertEqual(dot.text("expression"), "c") + + column = parse_one("a") + self.assertEqual(column.name, "a") + self.assertEqual(column.table, "") + + fields = parse_one("a.b.c.d") + self.assertIsInstance(fields, exp.Dot) + self.assertEqual(fields.text("expression"), "d") + self.assertEqual(fields.this.text("expression"), "c") + column = fields.find(exp.Column) + self.assertEqual(column.name, "b") + self.assertEqual(column.table, "a") + + column = parse_one("a[0].b") + self.assertIsInstance(column, exp.Dot) + self.assertIsInstance(column.this, exp.Bracket) + self.assertIsInstance(column.this.this, exp.Column) + + column = parse_one("a.*") + self.assertIsInstance(column, exp.Column) + self.assertIsInstance(column.this, exp.Star) + self.assertIsInstance(column.args["table"], exp.Identifier) + self.assertEqual(column.table, "a") + + self.assertIsInstance(parse_one("*"), exp.Star) + + def test_text(self): + column = parse_one("a.b.c") + self.assertEqual(column.text("expression"), "c") + self.assertEqual(column.text("y"), "") + self.assertEqual(parse_one("select * from x.y").find(exp.Table).text("db"), "x") + self.assertEqual(parse_one("select *").text("this"), "") + self.assertEqual(parse_one("1 + 1").text("this"), "1") + self.assertEqual(parse_one("'a'").text("this"), "a") + + def test_alias(self): + self.assertEqual(alias("foo", "bar").sql(), "foo AS bar") + self.assertEqual(alias("foo", "bar-1").sql(), 'foo AS "bar-1"') + self.assertEqual(alias("foo", "bar_1").sql(), "foo AS bar_1") + self.assertEqual(alias("foo * 2", "2bar").sql(), 'foo * 2 AS "2bar"') + self.assertEqual(alias('"foo"', "_bar").sql(), '"foo" AS "_bar"') + self.assertEqual(alias("foo", "bar", quoted=True).sql(), 'foo AS "bar"') + + def test_unit(self): + unit = parse_one("timestamp_trunc(current_timestamp, week(thursday))") + self.assertIsNotNone(unit.find(exp.CurrentTimestamp)) + week = unit.find(exp.Week) + self.assertEqual(week.this, exp.Var(this="thursday")) + + def test_identifier(self): + self.assertTrue(exp.to_identifier('"x"').quoted) + self.assertFalse(exp.to_identifier("x").quoted) + + def test_function_normalizer(self): + self.assertEqual( + parse_one("HELLO()").sql(normalize_functions="lower"), "hello()" + ) + self.assertEqual( + parse_one("hello()").sql(normalize_functions="upper"), "HELLO()" + ) + self.assertEqual(parse_one("heLLO()").sql(normalize_functions=None), "heLLO()") + self.assertEqual(parse_one("SUM(x)").sql(normalize_functions="lower"), "sum(x)") + self.assertEqual(parse_one("sum(x)").sql(normalize_functions="upper"), "SUM(x)") diff --git a/tests/test_generator.py b/tests/test_generator.py new file mode 100644 index 0000000..d64a818 --- /dev/null +++ b/tests/test_generator.py @@ -0,0 +1,30 @@ +import unittest + +from sqlglot.expressions import Func +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer + + +class TestGenerator(unittest.TestCase): + def test_fallback_function_sql(self): + class SpecialUDF(Func): + arg_types = {"a": True, "b": False} + + class NewParser(Parser): + FUNCTIONS = SpecialUDF.default_parser_mappings() + + tokens = Tokenizer().tokenize("SELECT 
SPECIAL_UDF(a) FROM x") + expression = NewParser().parse(tokens)[0] + self.assertEqual(expression.sql(), "SELECT SPECIAL_UDF(a) FROM x") + + def test_fallback_function_var_args_sql(self): + class SpecialUDF(Func): + arg_types = {"a": True, "expressions": False} + is_var_len_args = True + + class NewParser(Parser): + FUNCTIONS = SpecialUDF.default_parser_mappings() + + tokens = Tokenizer().tokenize("SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x") + expression = NewParser().parse(tokens)[0] + self.assertEqual(expression.sql(), "SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x") diff --git a/tests/test_helper.py b/tests/test_helper.py new file mode 100644 index 0000000..d37c03a --- /dev/null +++ b/tests/test_helper.py @@ -0,0 +1,31 @@ +import unittest + +from sqlglot.helper import tsort + + +class TestHelper(unittest.TestCase): + def test_tsort(self): + self.assertEqual(tsort({"a": []}), ["a"]) + self.assertEqual(tsort({"a": ["b", "b"]}), ["b", "a"]) + self.assertEqual(tsort({"a": ["b"]}), ["b", "a"]) + self.assertEqual(tsort({"a": ["c"], "b": [], "c": []}), ["c", "a", "b"]) + self.assertEqual( + tsort( + { + "a": ["b", "c"], + "b": ["c"], + "c": [], + "d": ["a"], + } + ), + ["c", "b", "a", "d"], + ) + + with self.assertRaises(ValueError): + tsort( + { + "a": ["b", "c"], + "b": ["a"], + "c": [], + } + ) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 0000000..40540b3 --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,276 @@ +import unittest + +from sqlglot import optimizer, parse_one, table +from sqlglot.errors import OptimizeError +from sqlglot.optimizer.schema import MappingSchema, ensure_schema +from sqlglot.optimizer.scope import traverse_scope +from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures + + +class TestOptimizer(unittest.TestCase): + maxDiff = None + + def setUp(self): + self.schema = { + "x": { + "a": "INT", + "b": "INT", + }, + "y": { + "b": "INT", + "c": "INT", + }, + "z": { + "b": "INT", + "c": "INT", + }, + } + + def check_file(self, file, func, pretty=False, **kwargs): + for meta, sql, expected in load_sql_fixture_pairs(f"optimizer/{file}.sql"): + dialect = meta.get("dialect") + with self.subTest(sql): + self.assertEqual( + func(parse_one(sql, read=dialect), **kwargs).sql( + pretty=pretty, dialect=dialect + ), + expected, + ) + + def test_optimize(self): + schema = { + "x": {"a": "INT", "b": "INT"}, + "y": {"a": "INT", "b": "INT"}, + "z": {"a": "INT", "c": "INT"}, + } + + self.check_file("optimizer", optimizer.optimize, pretty=True, schema=schema) + + def test_isolate_table_selects(self): + self.check_file( + "isolate_table_selects", + optimizer.isolate_table_selects.isolate_table_selects, + ) + + def test_qualify_tables(self): + self.check_file( + "qualify_tables", + optimizer.qualify_tables.qualify_tables, + db="db", + catalog="c", + ) + + def test_normalize(self): + self.assertEqual( + optimizer.normalize.normalize( + parse_one("x AND (y OR z)"), + dnf=True, + ).sql(), + "(x AND y) OR (x AND z)", + ) + + self.check_file( + "normalize", + optimizer.normalize.normalize, + ) + + def test_qualify_columns(self): + def qualify_columns(expression, **kwargs): + expression = optimizer.qualify_tables.qualify_tables(expression) + expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) + return expression + + self.check_file("qualify_columns", qualify_columns, schema=self.schema) + + def test_qualify_columns__invalid(self): + for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"): + 
with self.subTest(sql): + with self.assertRaises(OptimizeError): + optimizer.qualify_columns.qualify_columns( + parse_one(sql), schema=self.schema + ) + + def test_quote_identities(self): + self.check_file("quote_identities", optimizer.quote_identities.quote_identities) + + def test_pushdown_projection(self): + def pushdown_projections(expression, **kwargs): + expression = optimizer.qualify_tables.qualify_tables(expression) + expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) + expression = optimizer.pushdown_projections.pushdown_projections(expression) + return expression + + self.check_file( + "pushdown_projections", pushdown_projections, schema=self.schema + ) + + def test_simplify(self): + self.check_file("simplify", optimizer.simplify.simplify) + + def test_unnest_subqueries(self): + self.check_file( + "unnest_subqueries", + optimizer.unnest_subqueries.unnest_subqueries, + pretty=True, + ) + + def test_pushdown_predicates(self): + self.check_file( + "pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates + ) + + def test_expand_multi_table_selects(self): + self.check_file( + "expand_multi_table_selects", + optimizer.expand_multi_table_selects.expand_multi_table_selects, + ) + + def test_optimize_joins(self): + self.check_file( + "optimize_joins", + optimizer.optimize_joins.optimize_joins, + ) + + def test_eliminate_subqueries(self): + self.check_file( + "eliminate_subqueries", + optimizer.eliminate_subqueries.eliminate_subqueries, + pretty=True, + ) + + def test_tpch(self): + self.check_file( + "tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True + ) + + def test_schema(self): + schema = ensure_schema( + { + "x": { + "a": "uint64", + } + } + ) + self.assertEqual( + schema.column_names( + table( + "x", + ) + ), + ["a"], + ) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db")) + with self.assertRaises(ValueError): + schema.column_names(table("x2")) + + schema = ensure_schema( + { + "db": { + "x": { + "a": "uint64", + } + } + } + ) + self.assertEqual(schema.column_names(table("x", db="db")), ["a"]) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c")) + with self.assertRaises(ValueError): + schema.column_names(table("x")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db2")) + with self.assertRaises(ValueError): + schema.column_names(table("x2", db="db")) + + schema = ensure_schema( + { + "c": { + "db": { + "x": { + "a": "uint64", + } + } + } + } + ) + self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"]) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db")) + with self.assertRaises(ValueError): + schema.column_names(table("x")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c2")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db2")) + with self.assertRaises(ValueError): + schema.column_names(table("x2", db="db")) + + schema = ensure_schema( + MappingSchema( + { + "x": { + "a": "uint64", + } + } + ) + ) + self.assertEqual(schema.column_names(table("x")), ["a"]) + + with self.assertRaises(OptimizeError): + ensure_schema({}) + + def test_file_schema(self): + expression = parse_one( + """ + SELECT * + FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') + """ + ) + self.assertEqual( + """ +SELECT + 
"_q_0"."n_nationkey" AS "n_nationkey", + "_q_0"."n_name" AS "n_name", + "_q_0"."n_regionkey" AS "n_regionkey", + "_q_0"."n_comment" AS "n_comment" +FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') AS "_q_0" +""".strip(), + optimizer.optimize(expression).sql(pretty=True), + ) + + def test_scope(self): + sql = """ + WITH q AS ( + SELECT x.b FROM x + ), r AS ( + SELECT y.b FROM y + ) + SELECT + r.b, + s.b + FROM r + JOIN ( + SELECT y.c AS b FROM y + ) s + ON s.b = r.b + WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b) + """ + scopes = traverse_scope(parse_one(sql)) + self.assertEqual(len(scopes), 5) + self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") + self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") + self.assertEqual( + scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b" + ) + self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y") + self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql()) + + self.assertEqual(set(scopes[4].sources), {"q", "r", "s"}) + self.assertEqual(len(scopes[4].columns), 6) + self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"}) + self.assertEqual(scopes[4].source_columns("q"), []) + self.assertEqual(len(scopes[4].source_columns("r")), 2) + self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"}) diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 0000000..779083d --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,195 @@ +import unittest +from unittest.mock import patch + +from sqlglot import Parser, exp, parse, parse_one +from sqlglot.errors import ErrorLevel, ParseError +from tests.helpers import assert_logger_contains + + +class TestParser(unittest.TestCase): + def test_parse_empty(self): + self.assertIsNone(parse_one("")) + + def test_parse_into(self): + self.assertIsInstance(parse_one("left join foo", into=exp.Join), exp.Join) + self.assertIsInstance(parse_one("int", into=exp.DataType), exp.DataType) + self.assertIsInstance(parse_one("array", into=exp.DataType), exp.DataType) + + def test_column(self): + columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all( + exp.Column + ) + assert len(list(columns)) == 1 + + self.assertIsNotNone(parse_one("date").find(exp.Column)) + + def test_table(self): + tables = [ + t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table) + ] + self.assertEqual(tables, ["a", "b.c", "d"]) + + def test_select(self): + self.assertIsNotNone( + parse_one("select * from (select 1) x order by x.y").args["order"] + ) + self.assertIsNotNone( + parse_one("select * from x where a = (select 1) order by x.y").args["order"] + ) + self.assertEqual( + len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1 + ) + + def test_command(self): + expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1") + self.assertEqual(len(expressions), 3) + self.assertEqual(expressions[0].sql(), "SET x = 1") + self.assertEqual(expressions[1].sql(), "ADD JAR s3://a") + self.assertEqual(expressions[2].sql(), "SELECT 1") + + def test_identify(self): + expression = parse_one( + """ + SELECT a, "b", c AS c, d AS "D", e AS "y|z'" + FROM y."z" + """ + ) + + assert expression.expressions[0].text("this") == "a" + assert expression.expressions[1].text("this") == "b" + assert expression.expressions[2].text("alias") == "c" + assert expression.expressions[3].text("alias") == "D" + assert expression.expressions[4].text("alias") == "y|z'" + table = 
+
+    def test_multi(self):
+        expressions = parse(
+            """
+            SELECT * FROM a; SELECT * FROM b;
+            """
+        )
+
+        assert len(expressions) == 2
+        assert (
+            expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a"
+        )
+        assert (
+            expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b"
+        )
+
+    def test_expression(self):
+        ignore = Parser(error_level=ErrorLevel.IGNORE)
+        self.assertIsInstance(ignore.expression(exp.Hint, expressions=[""]), exp.Hint)
+        self.assertIsInstance(ignore.expression(exp.Hint, y=""), exp.Hint)
+        self.assertIsInstance(ignore.expression(exp.Hint), exp.Hint)
+
+        default = Parser()
+        self.assertIsInstance(default.expression(exp.Hint, expressions=[""]), exp.Hint)
+        default.expression(exp.Hint, y="")
+        default.expression(exp.Hint)
+        self.assertEqual(len(default.errors), 3)
+
+        warn = Parser(error_level=ErrorLevel.WARN)
+        warn.expression(exp.Hint, y="")
+        self.assertEqual(len(warn.errors), 2)
+
+    def test_parse_errors(self):
+        with self.assertRaises(ParseError):
+            parse_one("IF(a > 0, a, b, c)")
+
+        with self.assertRaises(ParseError):
+            parse_one("IF(a > 0)")
+
+        with self.assertRaises(ParseError):
+            parse_one("WITH cte AS (SELECT * FROM x)")
+
+    def test_space(self):
+        self.assertEqual(
+            parse_one("SELECT ROW() OVER(PARTITION BY x) FROM x GROUP BY y").sql(),
+            "SELECT ROW() OVER (PARTITION BY x) FROM x GROUP BY y",
+        )
+
+        self.assertEqual(
+            parse_one(
+                """SELECT * FROM x GROUP
+                BY y"""
+            ).sql(),
+            "SELECT * FROM x GROUP BY y",
+        )
+
+    def test_missing_by(self):
+        with self.assertRaises(ParseError):
+            parse_one("SELECT FROM x ORDER BY")
+
+    def test_annotations(self):
+        expression = parse_one(
+            """
+            SELECT
+                a #annotation1,
+                b as B #annotation2:testing ,
+                "test#annotation",c#annotation3, d #annotation4,
+                e #,
+                f # space
+            FROM foo
+            """
+        )
+
+        assert expression.expressions[0].name == "annotation1"
+        assert expression.expressions[1].name == "annotation2:testing "
+        assert expression.expressions[2].name == "test#annotation"
+        assert expression.expressions[3].name == "annotation3"
+        assert expression.expressions[4].name == "annotation4"
+        assert expression.expressions[5].name == ""
+        assert expression.expressions[6].name == " space"
+
+    def test_pretty_config_override(self):
+        self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")
+        with patch("sqlglot.pretty", True):
+            self.assertEqual(
+                parse_one("SELECT col FROM x").sql(), "SELECT\n  col\nFROM x"
+            )
+
+        self.assertEqual(
+            parse_one("SELECT col FROM x").sql(pretty=True), "SELECT\n  col\nFROM x"
+        )
+
+    @patch("sqlglot.parser.logger")
+    def test_comment_error_n(self, logger):
+        parse_one(
+            """CREATE TABLE x
+(
+-- test
+)""",
+            error_level=ErrorLevel.WARN,
+        )
+
+        assert_logger_contains(
+            "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 4, Col: 1.",
+            logger,
+        )
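+
+    # Editor's note (illustrative, not upstream code): with ErrorLevel.WARN
+    # the parser logs its errors instead of raising, which is what the mocked
+    # module logger and assert_logger_contains verify here, e.g.:
+    #
+    #   parse_one("SELECT FROM x ORDER BY", error_level=ErrorLevel.WARN)
+    #   # logs a ParseError message and returns without raising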
+
+    @patch("sqlglot.parser.logger")
+    def test_comment_error_r(self, logger):
+        parse_one(
+            """CREATE TABLE x (-- test\r)""",
+            error_level=ErrorLevel.WARN,
+        )

+        assert_logger_contains(
+            "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 2, Col: 1.",
+            logger,
+        )
+
+    @patch("sqlglot.parser.logger")
+    def test_create_table_error(self, logger):
+        parse_one(
+            """CREATE TABLE PARTITION""",
+            error_level=ErrorLevel.WARN,
+        )
+
+        assert_logger_contains(
+            "Expected table name",
+            logger,
+        )
diff --git a/tests/test_time.py b/tests/test_time.py
new file mode 100644
index 0000000..17821c2
--- /dev/null
+++ b/tests/test_time.py
@@ -0,0 +1,14 @@
+import unittest
+
+from sqlglot.time import format_time
+
+
+class TestTime(unittest.TestCase):
+    def test_format_time(self):
+        self.assertEqual(format_time("", {}), "")
+        self.assertEqual(format_time(" ", {}), " ")
+        mapping = {"a": "b", "aa": "c"}
+        self.assertEqual(format_time("a", mapping), "b")
+        self.assertEqual(format_time("aa", mapping), "c")
+        self.assertEqual(format_time("aaada", mapping), "cbdb")
+        self.assertEqual(format_time("da", mapping), "db")
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
new file mode 100644
index 0000000..2030109
--- /dev/null
+++ b/tests/test_transforms.py
@@ -0,0 +1,16 @@
+import unittest
+
+from sqlglot import parse_one
+from sqlglot.transforms import unalias_group
+
+
+class TestTransforms(unittest.TestCase):
+    def validate(self, transform, sql, target):
+        self.assertEqual(parse_one(sql).transform(transform).sql(), target)
+
+    def test_unalias_group(self):
+        self.validate(
+            unalias_group,
+            "SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, b, x.c, 4",
+            "SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, 2, x.c, 4",
+        )
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
new file mode 100644
index 0000000..28bcc7a
--- /dev/null
+++ b/tests/test_transpile.py
@@ -0,0 +1,349 @@
+import os
+import unittest
+from unittest import mock
+
+from sqlglot import parse_one, transpile
+from sqlglot.errors import ErrorLevel, ParseError, UnsupportedError
+from tests.helpers import (
+    assert_logger_contains,
+    load_sql_fixture_pairs,
+    load_sql_fixtures,
+)
+
+
+class TestTranspile(unittest.TestCase):
+    file_dir = os.path.dirname(__file__)
+    fixtures_dir = os.path.join(file_dir, "fixtures")
+    maxDiff = None
+
+    def validate(self, sql, target, **kwargs):
+        self.assertEqual(transpile(sql, **kwargs)[0], target)
+
+    def test_alias(self):
+        for key in ("union", "filter", "over", "from", "join"):
+            with self.subTest(f"alias {key}"):
+                self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}")
+                self.validate(f'SELECT x "{key}"', f'SELECT x AS "{key}"')
+
+                with self.assertRaises(ParseError):
+                    self.validate(f"SELECT x {key}", "")
+
+    def test_asc(self):
+        self.validate("SELECT x FROM y ORDER BY x ASC", "SELECT x FROM y ORDER BY x")
+
+    def test_paren(self):
+        with self.assertRaises(ParseError):
+            transpile("1 + (2 + 3")
+            transpile("select f(")
+
+    def test_some(self):
+        self.validate(
+            "SELECT * FROM x WHERE a = SOME (SELECT 1)",
+            "SELECT * FROM x WHERE a = ANY (SELECT 1)",
+        )
+
+    def test_space(self):
+        self.validate("SELECT MIN(3)>MIN(2)", "SELECT MIN(3) > MIN(2)")
+        self.validate("SELECT MIN(3)>=MIN(2)", "SELECT MIN(3) >= MIN(2)")
+        self.validate("SELECT 1>0", "SELECT 1 > 0")
+        self.validate("SELECT 3>=3", "SELECT 3 >= 3")
+
+    def test_comments(self):
+        self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo")
+        self.validate("SELECT 1 /* inline */ FROM foo -- comment", "SELECT 1 FROM foo")
+
+        self.validate(
+            """
+            SELECT 1 -- comment
+            FROM foo -- comment
+            """,
+            "SELECT 1 FROM foo",
+        )
+
+        self.validate(
+            """
+            SELECT 1 /* big comment
+             like this */
+            FROM foo -- comment
+            """,
+            "SELECT 1 FROM foo",
+        )
+
+    def test_types(self):
+        self.validate("INT x", "CAST(x AS INT)")
+        self.validate("VARCHAR x y", "CAST(x AS VARCHAR) AS y")
+        self.validate("STRING x y", "CAST(x AS TEXT) AS y")
+        self.validate("x::INT", "CAST(x AS INT)")
+        self.validate("x::INTEGER", "CAST(x AS INT)")
+        self.validate("x::INT y", "CAST(x AS INT) AS y")
+        self.validate("x::INT AS y", "CAST(x AS INT) AS y")
+        self.validate("x::INT::BOOLEAN", "CAST(CAST(x AS INT) AS BOOLEAN)")
+        self.validate("CAST(x::INT AS BOOLEAN)", "CAST(CAST(x AS INT) AS BOOLEAN)")
+        self.validate("CAST(x AS INT)::BOOLEAN", "CAST(CAST(x AS INT) AS BOOLEAN)")
+
+        with self.assertRaises(ParseError):
+            transpile("x::z")
+
+    def test_not_range(self):
+        self.validate("a NOT LIKE b", "NOT a LIKE b")
+        self.validate("a NOT BETWEEN b AND c", "NOT a BETWEEN b AND c")
+        self.validate("a NOT IN (1, 2)", "NOT a IN (1, 2)")
+        self.validate("a IS NOT NULL", "NOT a IS NULL")
+        self.validate("a LIKE TEXT y", "a LIKE CAST(y AS TEXT)")
+
+    def test_extract(self):
+        self.validate(
+            "EXTRACT(day FROM '2020-01-01'::TIMESTAMP)",
+            "EXTRACT(day FROM CAST('2020-01-01' AS TIMESTAMP))",
+        )
+        self.validate(
+            "EXTRACT(timezone FROM '2020-01-01'::TIMESTAMP)",
+            "EXTRACT(timezone FROM CAST('2020-01-01' AS TIMESTAMP))",
+        )
+        self.validate(
+            "EXTRACT(year FROM '2020-01-01'::TIMESTAMP WITH TIME ZONE)",
+            "EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMPTZ))",
+        )
+        self.validate(
+            "extract(month from '2021-01-31'::timestamp without time zone)",
+            "EXTRACT(month FROM CAST('2021-01-31' AS TIMESTAMP))",
+        )
+
+    def test_if(self):
+        self.validate(
+            "SELECT IF(a > 1, 1, 0) FROM foo",
+            "SELECT CASE WHEN a > 1 THEN 1 ELSE 0 END FROM foo",
+        )
+        self.validate(
+            "SELECT IF a > 1 THEN b END",
+            "SELECT CASE WHEN a > 1 THEN b END",
+        )
+        self.validate(
+            "SELECT IF a > 1 THEN b ELSE c END",
+            "SELECT CASE WHEN a > 1 THEN b ELSE c END",
+        )
+        self.validate(
+            "SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo"
+        )
+
+    def test_ignore_nulls(self):
+        self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)")
+
+    def test_time(self):
+        self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)")
+        self.validate(
+            "TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)"
+        )
+        self.validate(
+            "TIMESTAMP(9) WITH TIME ZONE '2020-01-01'",
+            "CAST('2020-01-01' AS TIMESTAMPTZ(9))",
+        )
+        self.validate(
+            "TIMESTAMP WITHOUT TIME ZONE '2020-01-01'",
+            "CAST('2020-01-01' AS TIMESTAMP)",
+        )
+        self.validate("'2020-01-01'::TIMESTAMP", "CAST('2020-01-01' AS TIMESTAMP)")
+        self.validate(
+            "'2020-01-01'::TIMESTAMP WITHOUT TIME ZONE",
+            "CAST('2020-01-01' AS TIMESTAMP)",
+        )
+        self.validate(
+            "'2020-01-01'::TIMESTAMP WITH TIME ZONE",
+            "CAST('2020-01-01' AS TIMESTAMPTZ)",
+        )
+        self.validate(
+            "timestamp with time zone '2025-11-20 00:00:00+00' AT TIME ZONE 'Africa/Cairo'",
+            "CAST('2025-11-20 00:00:00+00' AS TIMESTAMPTZ) AT TIME ZONE 'Africa/Cairo'",
+        )
+
+        self.validate("DATE '2020-01-01'", "CAST('2020-01-01' AS DATE)")
+        self.validate("'2020-01-01'::DATE", "CAST('2020-01-01' AS DATE)")
+        self.validate("STR_TO_TIME('x', 'y')", "STRPTIME('x', 'y')", write="duckdb")
+        self.validate(
+            "STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb"
+        )
+        self.validate("TIME_TO_STR(x, 'y')", "STRFTIME(x, 'y')", write="duckdb")
+        self.validate("TIME_TO_UNIX(x)", "EPOCH(x)", write="duckdb")
+        self.validate(
+            "UNIX_TO_STR(123, 'y')",
+            "STRFTIME(TO_TIMESTAMP(CAST(123 AS BIGINT)), 'y')",
+            write="duckdb",
+        )
"UNIX_TO_TIME(123)", + "TO_TIMESTAMP(CAST(123 AS BIGINT))", + write="duckdb", + ) + + self.validate( + "STR_TO_TIME(x, 'y')", + "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'y')) AS TIMESTAMP)", + write="hive", + ) + self.validate( + "STR_TO_TIME(x, 'yyyy-MM-dd HH:mm:ss')", + "CAST(x AS TIMESTAMP)", + write="hive", + ) + self.validate( + "STR_TO_TIME(x, 'yyyy-MM-dd')", + "CAST(x AS TIMESTAMP)", + write="hive", + ) + + self.validate( + "STR_TO_UNIX('x', 'y')", + "UNIX_TIMESTAMP('x', 'y')", + write="hive", + ) + self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="hive") + + self.validate("TIME_STR_TO_TIME(x)", "TIME_STR_TO_TIME(x)", write=None) + self.validate("TIME_STR_TO_UNIX(x)", "TIME_STR_TO_UNIX(x)", write=None) + self.validate("TIME_TO_TIME_STR(x)", "CAST(x AS TEXT)", write=None) + self.validate("TIME_TO_STR(x, 'y')", "TIME_TO_STR(x, 'y')", write=None) + self.validate("TIME_TO_UNIX(x)", "TIME_TO_UNIX(x)", write=None) + self.validate("UNIX_TO_STR(x, 'y')", "UNIX_TO_STR(x, 'y')", write=None) + self.validate("UNIX_TO_TIME(x)", "UNIX_TO_TIME(x)", write=None) + self.validate("UNIX_TO_TIME_STR(x)", "UNIX_TO_TIME_STR(x)", write=None) + self.validate("TIME_STR_TO_DATE(x)", "TIME_STR_TO_DATE(x)", write=None) + + self.validate("TIME_STR_TO_DATE(x)", "TO_DATE(x)", write="hive") + self.validate( + "UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive" + ) + self.validate( + "STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive" + ) + self.validate("IF(x > 1, x + 1)", "IF(x > 1, x + 1)", write="presto") + self.validate("IF(x > 1, 1 + 1)", "IF(x > 1, 1 + 1)", write="hive") + self.validate("IF(x > 1, 1, 0)", "IF(x > 1, 1, 0)", write="hive") + + self.validate( + "TIME_TO_UNIX(x)", + "UNIX_TIMESTAMP(x)", + write="hive", + ) + self.validate("UNIX_TO_STR(123, 'y')", "FROM_UNIXTIME(123, 'y')", write="hive") + self.validate( + "UNIX_TO_TIME(123)", + "FROM_UNIXTIME(123)", + write="hive", + ) + + self.validate("STR_TO_TIME('x', 'y')", "DATE_PARSE('x', 'y')", write="presto") + self.validate( + "STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto" + ) + self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="presto") + self.validate("TIME_TO_UNIX(x)", "TO_UNIXTIME(x)", write="presto") + self.validate( + "UNIX_TO_STR(123, 'y')", + "DATE_FORMAT(FROM_UNIXTIME(123), 'y')", + write="presto", + ) + self.validate("UNIX_TO_TIME(123)", "FROM_UNIXTIME(123)", write="presto") + + self.validate("STR_TO_TIME('x', 'y')", "TO_TIMESTAMP('x', 'y')", write="spark") + self.validate( + "STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark" + ) + self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="spark") + + self.validate( + "TIME_TO_UNIX(x)", + "UNIX_TIMESTAMP(x)", + write="spark", + ) + self.validate("UNIX_TO_STR(123, 'y')", "FROM_UNIXTIME(123, 'y')", write="spark") + self.validate( + "UNIX_TO_TIME(123)", + "FROM_UNIXTIME(123)", + write="spark", + ) + self.validate( + "CREATE TEMPORARY TABLE test AS SELECT 1", + "CREATE TEMPORARY VIEW test AS SELECT 1", + write="spark", + ) + + @mock.patch("sqlglot.helper.logger") + def test_index_offset(self, mock_logger): + self.validate("x[0]", "x[1]", write="presto", identity=False) + self.validate("x[1]", "x[0]", read="presto", identity=False) + mock_logger.warning.assert_any_call("Applying array index offset (%s)", 1) + mock_logger.warning.assert_any_call("Applying array index offset (%s)", -1) + + def test_identity(self): + self.assertEqual(transpile("")[0], "") + for sql in 
+
+    def test_identity(self):
+        self.assertEqual(transpile("")[0], "")
+        for sql in load_sql_fixtures("identity.sql"):
+            with self.subTest(sql):
+                self.assertEqual(transpile(sql)[0], sql.strip())
+
+    def test_partial(self):
+        for sql in load_sql_fixtures("partial.sql"):
+            with self.subTest(sql):
+                self.assertEqual(
+                    transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip()
+                )
+
+    def test_pretty(self):
+        for _, sql, pretty in load_sql_fixture_pairs("pretty.sql"):
+            with self.subTest(sql[:100]):
+                generated = transpile(sql, pretty=True)[0]
+                self.assertEqual(generated, pretty)
+                self.assertEqual(parse_one(sql), parse_one(pretty))
+
+    @mock.patch("sqlglot.parser.logger")
+    def test_error_level(self, logger):
+        invalid = "x + 1. ("
+        errors = [
+            "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 8.\n  x + 1. \033[4m(\033[0m",
+            "Expecting ). Line 1, Col: 8.\n  x + 1. \033[4m(\033[0m",
+        ]
+
+        transpile(invalid, error_level=ErrorLevel.WARN)
+        for error in errors:
+            assert_logger_contains(error, logger)
+
+        with self.assertRaises(ParseError) as ctx:
+            transpile(invalid, error_level=ErrorLevel.IMMEDIATE)
+        self.assertEqual(str(ctx.exception), errors[0])
+
+        with self.assertRaises(ParseError) as ctx:
+            transpile(invalid, error_level=ErrorLevel.RAISE)
+        self.assertEqual(str(ctx.exception), "\n\n".join(errors))
+
+        more_than_max_errors = "(((("
+        expected = (
+            "Expecting ). Line 1, Col: 4.\n  (((\033[4m(\033[0m\n\n"
+            "Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>. Line 1, Col: 4.\n  (((\033[4m(\033[0m\n\n"
+            "Expecting ). Line 1, Col: 4.\n  (((\033[4m(\033[0m\n\n"
+            "... and 2 more"
+        )
+        with self.assertRaises(ParseError) as ctx:
+            transpile(more_than_max_errors, error_level=ErrorLevel.RAISE)
+        self.assertEqual(str(ctx.exception), expected)
+
+    @mock.patch("sqlglot.generator.logger")
+    def test_unsupported_level(self, logger):
+        def unsupported(level):
+            transpile(
+                "SELECT MAP(a, b), MAP(a, b), MAP(a, b), MAP(a, b)",
+                read="presto",
+                write="hive",
+                unsupported_level=level,
+            )
+
+        error = "Cannot convert array columns into map use SparkSQL instead."
+
+        unsupported(ErrorLevel.WARN)
+        assert_logger_contains("\n".join([error] * 4), logger, level="warning")
+
+        with self.assertRaises(UnsupportedError) as ctx:
+            unsupported(ErrorLevel.RAISE)
+        self.assertEqual(str(ctx.exception).count(error), 3)
+
+        with self.assertRaises(UnsupportedError) as ctx:
+            unsupported(ErrorLevel.IMMEDIATE)
+        self.assertEqual(str(ctx.exception).count(error), 1)
--
cgit v1.2.3