summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-05-10 06:44:54 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-05-10 06:44:54 +0000
commitd2e9401b18925b5702c5c758af7d4f5b61deb493 (patch)
tree58dbf490c0457c2908751b3e4b63af13287381ee /tests
parentAdding upstream version 11.7.1. (diff)
downloadsqlglot-d2e9401b18925b5702c5c758af7d4f5b61deb493.tar.xz
sqlglot-d2e9401b18925b5702c5c758af7d4f5b61deb493.zip
Adding upstream version 12.2.0.upstream/12.2.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests')
-rw-r--r--tests/dataframe/integration/dataframe_validator.py6
-rw-r--r--tests/dialects/test_bigquery.py38
-rw-r--r--tests/dialects/test_clickhouse.py21
-rw-r--r--tests/dialects/test_dialect.py4
-rw-r--r--tests/dialects/test_duckdb.py27
-rw-r--r--tests/dialects/test_mysql.py126
-rw-r--r--tests/dialects/test_oracle.py18
-rw-r--r--tests/dialects/test_postgres.py76
-rw-r--r--tests/dialects/test_presto.py45
-rw-r--r--tests/dialects/test_redshift.py15
-rw-r--r--tests/dialects/test_snowflake.py31
-rw-r--r--tests/dialects/test_spark.py47
-rw-r--r--tests/dialects/test_starrocks.py8
-rw-r--r--tests/dialects/test_tsql.py12
-rw-r--r--tests/fixtures/identity.sql6
-rw-r--r--tests/fixtures/optimizer/qualify_columns.sql7
-rw-r--r--tests/fixtures/optimizer/qualify_tables.sql23
-rw-r--r--tests/fixtures/optimizer/tpc-ds/tpc-ds.sql16
-rw-r--r--tests/test_build.py19
-rw-r--r--tests/test_expressions.py5
-rw-r--r--tests/test_optimizer.py13
-rw-r--r--tests/test_tokens.py11
-rw-r--r--tests/test_transpile.py17
23 files changed, 466 insertions, 125 deletions
diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py
index 16f8922..c84a342 100644
--- a/tests/dataframe/integration/dataframe_validator.py
+++ b/tests/dataframe/integration/dataframe_validator.py
@@ -3,17 +3,13 @@ import unittest
import warnings
import sqlglot
-from sqlglot.helper import PYTHON_VERSION
from tests.helpers import SKIP_INTEGRATION
if t.TYPE_CHECKING:
from pyspark.sql import DataFrame as SparkDataFrame
-@unittest.skipIf(
- SKIP_INTEGRATION or PYTHON_VERSION > (3, 10),
- "Skipping Integration Tests since `SKIP_INTEGRATION` is set",
-)
+@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
class DataFrameValidator(unittest.TestCase):
spark = None
sqlglot = None
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index 703b7dc..87bba6f 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -6,10 +6,19 @@ class TestBigQuery(Validator):
dialect = "bigquery"
def test_bigquery(self):
+ self.validate_identity("DATE_TRUNC(col, WEEK(MONDAY))")
+ self.validate_identity("SELECT b'abc'")
+ self.validate_identity("""SELECT * FROM UNNEST(ARRAY<STRUCT<x INT64>>[1, 2])""")
self.validate_identity("SELECT AS STRUCT 1 AS a, 2 AS b")
+ self.validate_identity("SELECT DISTINCT AS STRUCT 1 AS a, 2 AS b")
self.validate_identity("SELECT AS VALUE STRUCT(1 AS a, 2 AS b)")
self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])")
self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))")
+ self.validate_identity("""CREATE TABLE x (a STRUCT<values ARRAY<INT64>>)""")
+ self.validate_identity("""CREATE TABLE x (a STRUCT<b STRING OPTIONS (description='b')>)""")
+ self.validate_identity(
+ """CREATE TABLE x (a STRING OPTIONS (description='x')) OPTIONS (table_expiration_days=1)"""
+ )
self.validate_identity(
"SELECT * FROM (SELECT * FROM `t`) AS a UNPIVOT((c) FOR c_name IN (v1, v2))"
)
@@ -98,6 +107,16 @@ class TestBigQuery(Validator):
},
)
self.validate_all(
+ "CAST(a AS BYTES)",
+ write={
+ "bigquery": "CAST(a AS BYTES)",
+ "duckdb": "CAST(a AS BLOB)",
+ "presto": "CAST(a AS VARBINARY)",
+ "hive": "CAST(a AS BINARY)",
+ "spark": "CAST(a AS BINARY)",
+ },
+ )
+ self.validate_all(
"CAST(a AS NUMERIC)",
write={
"bigquery": "CAST(a AS NUMERIC)",
@@ -173,7 +192,6 @@ class TestBigQuery(Validator):
"current_datetime",
write={
"bigquery": "CURRENT_DATETIME()",
- "duckdb": "CURRENT_DATETIME()",
"presto": "CURRENT_DATETIME()",
"hive": "CURRENT_DATETIME()",
"spark": "CURRENT_DATETIME()",
@@ -183,7 +201,7 @@ class TestBigQuery(Validator):
"current_time",
write={
"bigquery": "CURRENT_TIME()",
- "duckdb": "CURRENT_TIME()",
+ "duckdb": "CURRENT_TIME",
"presto": "CURRENT_TIME()",
"hive": "CURRENT_TIME()",
"spark": "CURRENT_TIME()",
@@ -193,7 +211,7 @@ class TestBigQuery(Validator):
"current_timestamp",
write={
"bigquery": "CURRENT_TIMESTAMP()",
- "duckdb": "CURRENT_TIMESTAMP()",
+ "duckdb": "CURRENT_TIMESTAMP",
"postgres": "CURRENT_TIMESTAMP",
"presto": "CURRENT_TIMESTAMP",
"hive": "CURRENT_TIMESTAMP()",
@@ -204,7 +222,7 @@ class TestBigQuery(Validator):
"current_timestamp()",
write={
"bigquery": "CURRENT_TIMESTAMP()",
- "duckdb": "CURRENT_TIMESTAMP()",
+ "duckdb": "CURRENT_TIMESTAMP",
"postgres": "CURRENT_TIMESTAMP",
"presto": "CURRENT_TIMESTAMP",
"hive": "CURRENT_TIMESTAMP()",
@@ -343,6 +361,18 @@ class TestBigQuery(Validator):
},
)
self.validate_all(
+ "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab",
+ write={
+ "bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS _c0, 'test' AS _c1)])",
+ },
+ )
+ self.validate_all(
+ "SELECT cola, colb FROM (VALUES (1, 'test'))",
+ write={
+ "bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS _c0, 'test' AS _c1)])",
+ },
+ )
+ self.validate_all(
"SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)",
write={
"spark": "SELECT cola, colb, colc FROM VALUES (1, 'test', NULL) AS tab(cola, colb, colc)",
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index 9fd2b45..1060881 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -65,3 +65,24 @@ class TestClickhouse(Validator):
self.validate_identity("WITH SUM(bytes) AS foo SELECT foo FROM system.parts")
self.validate_identity("WITH (SELECT foo) AS bar SELECT bar + 5")
self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1")
+
+ def test_signed_and_unsigned_types(self):
+ data_types = [
+ "UInt8",
+ "UInt16",
+ "UInt32",
+ "UInt64",
+ "UInt128",
+ "UInt256",
+ "Int8",
+ "Int16",
+ "Int32",
+ "Int64",
+ "Int128",
+ "Int256",
+ ]
+ for data_type in data_types:
+ self.validate_all(
+ f"pow(2, 32)::{data_type}",
+ write={"clickhouse": f"CAST(POWER(2, 32) AS {data_type})"},
+ )
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index bcbbfd6..f12273b 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -95,7 +95,7 @@ class TestDialect(Validator):
self.validate_all(
"CAST(a AS BINARY(4))",
write={
- "bigquery": "CAST(a AS BINARY)",
+ "bigquery": "CAST(a AS BYTES)",
"clickhouse": "CAST(a AS BINARY(4))",
"drill": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS BLOB(4))",
@@ -114,7 +114,7 @@ class TestDialect(Validator):
self.validate_all(
"CAST(a AS VARBINARY(4))",
write={
- "bigquery": "CAST(a AS VARBINARY)",
+ "bigquery": "CAST(a AS BYTES)",
"clickhouse": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS BLOB(4))",
"mysql": "CAST(a AS VARBINARY(4))",
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index 9e0040c..8c1b748 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -6,6 +6,9 @@ class TestDuckDB(Validator):
dialect = "duckdb"
def test_time(self):
+ self.validate_identity("SELECT CURRENT_DATE")
+ self.validate_identity("SELECT CURRENT_TIMESTAMP")
+
self.validate_all(
"EPOCH(x)",
read={
@@ -24,7 +27,7 @@ class TestDuckDB(Validator):
"bigquery": "UNIX_TO_TIME(x / 1000)",
"duckdb": "TO_TIMESTAMP(x / 1000)",
"presto": "FROM_UNIXTIME(x / 1000)",
- "spark": "FROM_UNIXTIME(x / 1000)",
+ "spark": "CAST(FROM_UNIXTIME(x / 1000) AS TIMESTAMP)",
},
)
self.validate_all(
@@ -124,18 +127,34 @@ class TestDuckDB(Validator):
self.validate_identity("SELECT {'a': 1} AS x")
self.validate_identity("SELECT {'a': {'b': {'c': 1}}, 'd': {'e': 2}} AS x")
self.validate_identity("SELECT {'x': 1, 'y': 2, 'z': 3}")
- self.validate_identity(
- "SELECT {'yes': 'duck', 'maybe': 'goose', 'huh': NULL, 'no': 'heron'}"
- )
self.validate_identity("SELECT {'key1': 'string', 'key2': 1, 'key3': 12.345}")
self.validate_identity("SELECT ROW(x, x + 1, y) FROM (SELECT 1 AS x, 'a' AS y)")
self.validate_identity("SELECT (x, x + 1, y) FROM (SELECT 1 AS x, 'a' AS y)")
self.validate_identity("SELECT a.x FROM (SELECT {'x': 1, 'y': 2, 'z': 3} AS a)")
self.validate_identity("ATTACH DATABASE ':memory:' AS new_database")
self.validate_identity(
+ "SELECT {'yes': 'duck', 'maybe': 'goose', 'huh': NULL, 'no': 'heron'}"
+ )
+ self.validate_identity(
"SELECT a['x space'] FROM (SELECT {'x space': 1, 'y': 2, 'z': 3} AS a)"
)
+ self.validate_all("0b1010", write={"": "0 AS b1010"})
+ self.validate_all("0x1010", write={"": "0 AS x1010"})
+ self.validate_all(
+ """SELECT DATEDIFF('day', t1."A", t1."B") FROM "table" AS t1""",
+ write={
+ "duckdb": """SELECT DATE_DIFF('day', t1."A", t1."B") FROM "table" AS t1""",
+ "trino": """SELECT DATE_DIFF('day', t1."A", t1."B") FROM "table" AS t1""",
+ },
+ )
+ self.validate_all(
+ "SELECT DATE_DIFF('day', DATE '2020-01-01', DATE '2020-01-05')",
+ write={
+ "duckdb": "SELECT DATE_DIFF('day', CAST('2020-01-01' AS DATE), CAST('2020-01-05' AS DATE))",
+ "trino": "SELECT DATE_DIFF('day', CAST('2020-01-01' AS DATE), CAST('2020-01-05' AS DATE))",
+ },
+ )
self.validate_all("x ~ y", write={"duckdb": "REGEXP_MATCHES(x, y)"})
self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'})
self.validate_all(
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index 524d95e..f31b1b9 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -12,6 +12,7 @@ class TestMySQL(Validator):
"duckdb": "CREATE TABLE z (a INT)",
"mysql": "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'",
"spark": "CREATE TABLE z (a INT) COMMENT 'x'",
+ "sqlite": "CREATE TABLE z (a INTEGER)",
},
)
self.validate_all(
@@ -24,6 +25,19 @@ class TestMySQL(Validator):
"INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE SET x.id = 1"
)
+ self.validate_all(
+ "CREATE TABLE x (id int not null auto_increment, primary key (id))",
+ write={
+ "sqlite": "CREATE TABLE x (id INTEGER NOT NULL AUTOINCREMENT PRIMARY KEY)",
+ },
+ )
+ self.validate_all(
+ "CREATE TABLE x (id int not null auto_increment)",
+ write={
+ "sqlite": "CREATE TABLE x (id INTEGER NOT NULL)",
+ },
+ )
+
def test_identity(self):
self.validate_identity("SELECT CURRENT_TIMESTAMP(6)")
self.validate_identity("x ->> '$.name'")
@@ -150,47 +164,81 @@ class TestMySQL(Validator):
)
def test_hexadecimal_literal(self):
- self.validate_all(
- "SELECT 0xCC",
- write={
- "mysql": "SELECT x'CC'",
- "sqlite": "SELECT x'CC'",
- "spark": "SELECT X'CC'",
- "trino": "SELECT X'CC'",
- "bigquery": "SELECT 0xCC",
- "oracle": "SELECT 204",
- },
- )
- self.validate_all(
- "SELECT X'1A'",
- write={
- "mysql": "SELECT x'1A'",
- },
- )
- self.validate_all(
- "SELECT 0xz",
- write={
- "mysql": "SELECT `0xz`",
- },
- )
+ write_CC = {
+ "bigquery": "SELECT 0xCC",
+ "clickhouse": "SELECT 0xCC",
+ "databricks": "SELECT 204",
+ "drill": "SELECT 204",
+ "duckdb": "SELECT 204",
+ "hive": "SELECT 204",
+ "mysql": "SELECT x'CC'",
+ "oracle": "SELECT 204",
+ "postgres": "SELECT x'CC'",
+ "presto": "SELECT 204",
+ "redshift": "SELECT 204",
+ "snowflake": "SELECT x'CC'",
+ "spark": "SELECT X'CC'",
+ "sqlite": "SELECT x'CC'",
+ "starrocks": "SELECT x'CC'",
+ "tableau": "SELECT 204",
+ "teradata": "SELECT 204",
+ "trino": "SELECT X'CC'",
+ "tsql": "SELECT 0xCC",
+ }
+ write_CC_with_leading_zeros = {
+ "bigquery": "SELECT 0x0000CC",
+ "clickhouse": "SELECT 0x0000CC",
+ "databricks": "SELECT 204",
+ "drill": "SELECT 204",
+ "duckdb": "SELECT 204",
+ "hive": "SELECT 204",
+ "mysql": "SELECT x'0000CC'",
+ "oracle": "SELECT 204",
+ "postgres": "SELECT x'0000CC'",
+ "presto": "SELECT 204",
+ "redshift": "SELECT 204",
+ "snowflake": "SELECT x'0000CC'",
+ "spark": "SELECT X'0000CC'",
+ "sqlite": "SELECT x'0000CC'",
+ "starrocks": "SELECT x'0000CC'",
+ "tableau": "SELECT 204",
+ "teradata": "SELECT 204",
+ "trino": "SELECT X'0000CC'",
+ "tsql": "SELECT 0x0000CC",
+ }
+
+ self.validate_all("SELECT X'1A'", write={"mysql": "SELECT x'1A'"})
+ self.validate_all("SELECT 0xz", write={"mysql": "SELECT `0xz`"})
+ self.validate_all("SELECT 0xCC", write=write_CC)
+ self.validate_all("SELECT 0xCC ", write=write_CC)
+ self.validate_all("SELECT x'CC'", write=write_CC)
+ self.validate_all("SELECT 0x0000CC", write=write_CC_with_leading_zeros)
+ self.validate_all("SELECT x'0000CC'", write=write_CC_with_leading_zeros)
def test_bits_literal(self):
- self.validate_all(
- "SELECT 0b1011",
- write={
- "mysql": "SELECT b'1011'",
- "postgres": "SELECT b'1011'",
- "oracle": "SELECT 11",
- },
- )
- self.validate_all(
- "SELECT B'1011'",
- write={
- "mysql": "SELECT b'1011'",
- "postgres": "SELECT b'1011'",
- "oracle": "SELECT 11",
- },
- )
+ write_1011 = {
+ "bigquery": "SELECT 11",
+ "clickhouse": "SELECT 0b1011",
+ "databricks": "SELECT 11",
+ "drill": "SELECT 11",
+ "hive": "SELECT 11",
+ "mysql": "SELECT b'1011'",
+ "oracle": "SELECT 11",
+ "postgres": "SELECT b'1011'",
+ "presto": "SELECT 11",
+ "redshift": "SELECT 11",
+ "snowflake": "SELECT 11",
+ "spark": "SELECT 11",
+ "sqlite": "SELECT 11",
+ "mysql": "SELECT b'1011'",
+ "tableau": "SELECT 11",
+ "teradata": "SELECT 11",
+ "trino": "SELECT 11",
+ "tsql": "SELECT 11",
+ }
+
+ self.validate_all("SELECT 0b1011", write=write_1011)
+ self.validate_all("SELECT b'1011'", write=write_1011)
def test_string_literals(self):
self.validate_all(
diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py
index dd297d6..88c79fd 100644
--- a/tests/dialects/test_oracle.py
+++ b/tests/dialects/test_oracle.py
@@ -5,6 +5,8 @@ class TestOracle(Validator):
dialect = "oracle"
def test_oracle(self):
+ self.validate_identity("SELECT * FROM table_name@dblink_name.database_link_domain")
+ self.validate_identity("SELECT * FROM table_name SAMPLE (25) s")
self.validate_identity("SELECT * FROM V$SESSION")
self.validate_identity(
"SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name"
@@ -17,7 +19,6 @@ class TestOracle(Validator):
"": "IFNULL(NULL, 1)",
},
)
-
self.validate_all(
"DATE '2022-01-01'",
write={
@@ -28,6 +29,21 @@ class TestOracle(Validator):
},
)
+ self.validate_all(
+ "x::binary_double",
+ write={
+ "oracle": "CAST(x AS DOUBLE PRECISION)",
+ "": "CAST(x AS DOUBLE)",
+ },
+ )
+ self.validate_all(
+ "x::binary_float",
+ write={
+ "oracle": "CAST(x AS FLOAT)",
+ "": "CAST(x AS FLOAT)",
+ },
+ )
+
def test_join_marker(self):
self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y")
self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)")
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index e2f9c41..b535a84 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -98,6 +98,21 @@ class TestPostgres(Validator):
self.validate_identity("STRING_AGG(x, ',' ORDER BY y DESC)")
self.validate_identity("STRING_AGG(DISTINCT x, ',' ORDER BY y DESC)")
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
+ self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
+ self.validate_identity("SELECT e'\\xDEADBEEF'")
+ self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
+ self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')")
+ self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
+ self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
+ self.validate_identity("""SELECT * FROM JSON_TO_RECORDSET(z) AS y("rank" INT)""")
+ self.validate_identity("x ~ 'y'")
+ self.validate_identity("x ~* 'y'")
+ self.validate_identity(
+ "SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)"
+ )
+ self.validate_identity(
+ "CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"
+ )
self.validate_identity(
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
)
@@ -107,37 +122,31 @@ class TestPostgres(Validator):
self.validate_identity(
'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')'
)
- self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')")
self.validate_identity(
"SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')"
)
self.validate_identity(
"SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))"
)
- self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
- self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
self.validate_identity(
"SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')"
)
- self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
- self.validate_identity("SELECT e'\\xDEADBEEF'")
- self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
+
+ self.validate_all(
+ "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY amount)",
+ write={
+ "databricks": "SELECT PERCENTILE_APPROX(amount, 0.5)",
+ "presto": "SELECT APPROX_PERCENTILE(amount, 0.5)",
+ "spark": "SELECT PERCENTILE_APPROX(amount, 0.5)",
+ "trino": "SELECT APPROX_PERCENTILE(amount, 0.5)",
+ },
+ )
self.validate_all(
"e'x'",
write={
"mysql": "x",
},
)
- self.validate_identity("""SELECT * FROM JSON_TO_RECORDSET(z) AS y("rank" INT)""")
- self.validate_identity(
- "SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)"
- )
- self.validate_identity(
- "CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)"
- )
- self.validate_identity("x ~ 'y'")
- self.validate_identity("x ~* 'y'")
-
self.validate_all(
"SELECT DATE_PART('isodow'::varchar(6), current_date)",
write={
@@ -198,6 +207,33 @@ class TestPostgres(Validator):
},
)
self.validate_all(
+ "GENERATE_SERIES(a, b)",
+ write={
+ "postgres": "GENERATE_SERIES(a, b)",
+ "presto": "SEQUENCE(a, b)",
+ "trino": "SEQUENCE(a, b)",
+ "tsql": "GENERATE_SERIES(a, b)",
+ },
+ )
+ self.validate_all(
+ "GENERATE_SERIES(a, b)",
+ read={
+ "postgres": "GENERATE_SERIES(a, b)",
+ "presto": "SEQUENCE(a, b)",
+ "trino": "SEQUENCE(a, b)",
+ "tsql": "GENERATE_SERIES(a, b)",
+ },
+ )
+ self.validate_all(
+ "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)",
+ write={
+ "postgres": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)",
+ "presto": "SELECT * FROM t CROSS JOIN UNNEST(SEQUENCE(2, 4))",
+ "trino": "SELECT * FROM t CROSS JOIN UNNEST(SEQUENCE(2, 4))",
+ "tsql": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)",
+ },
+ )
+ self.validate_all(
"END WORK AND NO CHAIN",
write={"postgres": "COMMIT AND NO CHAIN"},
)
@@ -464,6 +500,14 @@ class TestPostgres(Validator):
},
)
+ self.validate_all(
+ "x / y ^ z",
+ write={
+ "": "x / POWER(y, z)",
+ "postgres": "x / y ^ z",
+ },
+ )
+
self.assertIsInstance(parse_one("id::UUID", read="postgres"), exp.TryCast)
def test_bool_or(self):
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index 3080476..15962cc 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -7,6 +7,26 @@ class TestPresto(Validator):
def test_cast(self):
self.validate_all(
+ "FROM_BASE64(x)",
+ read={
+ "hive": "UNBASE64(x)",
+ },
+ write={
+ "hive": "UNBASE64(x)",
+ "presto": "FROM_BASE64(x)",
+ },
+ )
+ self.validate_all(
+ "TO_BASE64(x)",
+ read={
+ "hive": "BASE64(x)",
+ },
+ write={
+ "hive": "BASE64(x)",
+ "presto": "TO_BASE64(x)",
+ },
+ )
+ self.validate_all(
"CAST(a AS ARRAY(INT))",
write={
"bigquery": "CAST(a AS ARRAY<INT64>)",
@@ -105,6 +125,13 @@ class TestPresto(Validator):
"spark": "SIZE(x)",
},
)
+ self.validate_all(
+ "ARRAY_JOIN(x, '-', 'a')",
+ write={
+ "hive": "CONCAT_WS('-', x)",
+ "spark": "ARRAY_JOIN(x, '-', 'a')",
+ },
+ )
def test_interval_plural_to_singular(self):
# Microseconds, weeks and quarters are not supported in Presto/Trino INTERVAL literals
@@ -134,6 +161,14 @@ class TestPresto(Validator):
self.validate_identity("VAR_POP(a)")
self.validate_all(
+ "SELECT FROM_UNIXTIME(col) FROM tbl",
+ write={
+ "presto": "SELECT FROM_UNIXTIME(col) FROM tbl",
+ "spark": "SELECT CAST(FROM_UNIXTIME(col) AS TIMESTAMP) FROM tbl",
+ "trino": "SELECT FROM_UNIXTIME(col) FROM tbl",
+ },
+ )
+ self.validate_all(
"DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')",
write={
"duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')",
@@ -181,7 +216,7 @@ class TestPresto(Validator):
"duckdb": "TO_TIMESTAMP(x)",
"presto": "FROM_UNIXTIME(x)",
"hive": "FROM_UNIXTIME(x)",
- "spark": "FROM_UNIXTIME(x)",
+ "spark": "CAST(FROM_UNIXTIME(x) AS TIMESTAMP)",
},
)
self.validate_all(
@@ -583,6 +618,14 @@ class TestPresto(Validator):
},
)
+ self.validate_all(
+ "JSON_FORMAT(JSON 'x')",
+ write={
+ "presto": "JSON_FORMAT(CAST('x' AS JSON))",
+ "spark": "TO_JSON('x')",
+ },
+ )
+
def test_encode_decode(self):
self.validate_all(
"TO_UTF8(x)",
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index e5bd0e5..f75480e 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -101,7 +101,22 @@ class TestRedshift(Validator):
self.validate_all(
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
write={
+ "bigquery": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
+ "databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
+ "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
+ "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
+ "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE `_row_number` = 1",
+ "oracle": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
+ "presto": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
"redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1',
+ "snowflake": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1',
+ "spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
+ "sqlite": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
+ "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE `_row_number` = 1",
+ "tableau": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
+ "teradata": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
+ "trino": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
+ "tsql": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
},
)
self.validate_all(
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 5c8b096..57ee235 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -227,7 +227,7 @@ class TestSnowflake(Validator):
write={
"bigquery": "SELECT UNIX_TO_TIME(1659981729)",
"snowflake": "SELECT TO_TIMESTAMP(1659981729)",
- "spark": "SELECT FROM_UNIXTIME(1659981729)",
+ "spark": "SELECT CAST(FROM_UNIXTIME(1659981729) AS TIMESTAMP)",
},
)
self.validate_all(
@@ -243,7 +243,7 @@ class TestSnowflake(Validator):
write={
"bigquery": "SELECT UNIX_TO_TIME('1659981729')",
"snowflake": "SELECT TO_TIMESTAMP('1659981729')",
- "spark": "SELECT FROM_UNIXTIME('1659981729')",
+ "spark": "SELECT CAST(FROM_UNIXTIME('1659981729') AS TIMESTAMP)",
},
)
self.validate_all(
@@ -401,7 +401,7 @@ class TestSnowflake(Validator):
self.validate_all(
r"SELECT FIRST_VALUE(TABLE1.COLUMN1 RESPECT NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
write={
- "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
+ "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1 RESPECT NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
},
)
self.validate_all(
@@ -426,30 +426,19 @@ class TestSnowflake(Validator):
def test_sample(self):
self.validate_identity("SELECT * FROM testtable TABLESAMPLE BERNOULLI (20.3)")
self.validate_identity("SELECT * FROM testtable TABLESAMPLE (100)")
+ self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)")
+ self.validate_identity("SELECT * FROM testtable TABLESAMPLE (10 ROWS)")
+ self.validate_identity("SELECT * FROM testtable SAMPLE (10)")
+ self.validate_identity("SELECT * FROM testtable SAMPLE ROW (0)")
+ self.validate_identity("SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)")
self.validate_identity(
"SELECT i, j FROM table1 AS t1 INNER JOIN table2 AS t2 TABLESAMPLE (50) WHERE t2.j = t1.i"
)
self.validate_identity(
"SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1)"
)
- self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)")
- self.validate_identity("SELECT * FROM testtable TABLESAMPLE (10 ROWS)")
self.validate_all(
- "SELECT * FROM testtable SAMPLE (10)",
- write={"snowflake": "SELECT * FROM testtable TABLESAMPLE (10)"},
- )
- self.validate_all(
- "SELECT * FROM testtable SAMPLE ROW (0)",
- write={"snowflake": "SELECT * FROM testtable TABLESAMPLE ROW (0)"},
- )
- self.validate_all(
- "SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)",
- write={
- "snowflake": "SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)",
- },
- )
- self.validate_all(
"""
SELECT i, j
FROM
@@ -458,13 +447,13 @@ class TestSnowflake(Validator):
table2 AS t2 SAMPLE (50) -- 50% of rows in table2
WHERE t2.j = t1.i""",
write={
- "snowflake": "SELECT i, j FROM table1 AS t1 TABLESAMPLE (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 TABLESAMPLE (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i",
+ "snowflake": "SELECT i, j FROM table1 AS t1 SAMPLE (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 SAMPLE (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i",
},
)
self.validate_all(
"SELECT * FROM testtable SAMPLE BLOCK (0.012) REPEATABLE (99992)",
write={
- "snowflake": "SELECT * FROM testtable TABLESAMPLE BLOCK (0.012) SEED (99992)",
+ "snowflake": "SELECT * FROM testtable SAMPLE BLOCK (0.012) SEED (99992)",
},
)
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index bfaed53..be03b4e 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -215,40 +215,45 @@ TBLPROPERTIES (
self.validate_identity("SPLIT(str, pattern, lim)")
self.validate_all(
- "BOOLEAN(x)",
- write={
- "": "CAST(x AS BOOLEAN)",
- "spark": "CAST(x AS BOOLEAN)",
+ "SELECT * FROM produce PIVOT(SUM(produce.sales) FOR quarter IN ('Q1', 'Q2'))",
+ read={
+ "snowflake": "SELECT * FROM produce PIVOT (SUM(produce.sales) FOR produce.quarter IN ('Q1', 'Q2'))",
},
)
self.validate_all(
- "INT(x)",
- write={
- "": "CAST(x AS INT)",
- "spark": "CAST(x AS INT)",
- },
- )
- self.validate_all(
- "STRING(x)",
- write={
- "": "CAST(x AS TEXT)",
- "spark": "CAST(x AS STRING)",
+ "SELECT * FROM produce AS p PIVOT(SUM(p.sales) AS sales FOR quarter IN ('Q1' AS Q1, 'Q2' AS Q1))",
+ read={
+ "bigquery": "SELECT * FROM produce AS p PIVOT(SUM(p.sales) AS sales FOR p.quarter IN ('Q1' AS Q1, 'Q2' AS Q1))",
},
)
self.validate_all(
- "DATE(x)",
+ "SELECT DATEDIFF(MONTH, '2020-01-01', '2020-03-05')",
write={
- "": "CAST(x AS DATE)",
- "spark": "CAST(x AS DATE)",
+ "databricks": "SELECT DATEDIFF(MONTH, TO_DATE('2020-01-01'), TO_DATE('2020-03-05'))",
+ "hive": "SELECT MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01'))",
+ "presto": "SELECT DATE_DIFF('MONTH', CAST(SUBSTR(CAST('2020-01-01' AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST('2020-03-05' AS VARCHAR), 1, 10) AS DATE))",
+ "spark": "SELECT DATEDIFF(MONTH, TO_DATE('2020-01-01'), TO_DATE('2020-03-05'))",
+ "spark2": "SELECT MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01'))",
+ "trino": "SELECT DATE_DIFF('MONTH', CAST(SUBSTR(CAST('2020-01-01' AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST('2020-03-05' AS VARCHAR), 1, 10) AS DATE))",
},
)
+
+ for data_type in ("BOOLEAN", "DATE", "DOUBLE", "FLOAT", "INT", "TIMESTAMP"):
+ self.validate_all(
+ f"{data_type}(x)",
+ write={
+ "": f"CAST(x AS {data_type})",
+ "spark": f"CAST(x AS {data_type})",
+ },
+ )
self.validate_all(
- "TIMESTAMP(x)",
+ "STRING(x)",
write={
- "": "CAST(x AS TIMESTAMP)",
- "spark": "CAST(x AS TIMESTAMP)",
+ "": "CAST(x AS TEXT)",
+ "spark": "CAST(x AS STRING)",
},
)
+
self.validate_all(
"CAST(x AS TIMESTAMP)", read={"trino": "CAST(x AS TIMESTAMP(6) WITH TIME ZONE)"}
)
diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py
index b33231c..96e20da 100644
--- a/tests/dialects/test_starrocks.py
+++ b/tests/dialects/test_starrocks.py
@@ -10,3 +10,11 @@ class TestMySQL(Validator):
def test_time(self):
self.validate_identity("TIMESTAMP('2022-01-01')")
+
+ def test_regex(self):
+ self.validate_all(
+ "SELECT REGEXP_LIKE(abc, '%foo%')",
+ write={
+ "starrocks": "SELECT REGEXP(abc, '%foo%')",
+ },
+ )
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index b6e893c..3a3ac73 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -485,26 +485,30 @@ WHERE
def test_date_diff(self):
self.validate_identity("SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')")
+
self.validate_all(
"SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')",
write={
"tsql": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')",
- "spark": "SELECT MONTHS_BETWEEN('2021/01/01', '2020/01/01') / 12",
+ "spark": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')",
+ "spark2": "SELECT MONTHS_BETWEEN('2021/01/01', '2020/01/01') / 12",
},
)
self.validate_all(
"SELECT DATEDIFF(mm, 'start','end')",
write={
- "spark": "SELECT MONTHS_BETWEEN('end', 'start')",
- "tsql": "SELECT DATEDIFF(month, 'start', 'end')",
"databricks": "SELECT DATEDIFF(month, 'start', 'end')",
+ "spark2": "SELECT MONTHS_BETWEEN('end', 'start')",
+ "tsql": "SELECT DATEDIFF(month, 'start', 'end')",
},
)
self.validate_all(
"SELECT DATEDIFF(quarter, 'start', 'end')",
write={
- "spark": "SELECT MONTHS_BETWEEN('end', 'start') / 3",
"databricks": "SELECT DATEDIFF(quarter, 'start', 'end')",
+ "spark": "SELECT DATEDIFF(quarter, 'start', 'end')",
+ "spark2": "SELECT MONTHS_BETWEEN('end', 'start') / 3",
+ "tsql": "SELECT DATEDIFF(quarter, 'start', 'end')",
},
)
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index a08a7a8..ea695c9 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -85,6 +85,7 @@ x IS TRUE
x IS FALSE
x IS TRUE IS TRUE
x LIKE y IS TRUE
+TRIM('a' || 'b')
MAP()
GREATEST(x)
LEAST(y)
@@ -104,6 +105,7 @@ ARRAY(time, foo)
ARRAY(foo, time)
ARRAY(LENGTH(waiter_name) > 0)
ARRAY_CONTAINS(x, 1)
+x.EXTRACT(1)
EXTRACT(x FROM y)
EXTRACT(DATE FROM y)
EXTRACT(WEEK(monday) FROM created_at)
@@ -215,6 +217,7 @@ SELECT COUNT(DISTINCT a, b)
SELECT COUNT(DISTINCT a, b + 1)
SELECT SUM(DISTINCT x)
SELECT SUM(x IGNORE NULLS) AS x
+SELECT COUNT(x RESPECT NULLS)
SELECT TRUNCATE(a, b)
SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x
SELECT ARRAY_AGG(STRUCT(x, x AS y) ORDER BY z DESC) AS x
@@ -820,3 +823,6 @@ JSON_OBJECT('x': 1 RETURNING VARBINARY FORMAT JSON ENCODING UTF8)
SELECT if.x
SELECT NEXT VALUE FOR db.schema.sequence_name
SELECT NEXT VALUE FOR db.schema.sequence_name OVER (ORDER BY foo), col
+SELECT PERCENTILE_CONT(x, 0.5) OVER ()
+SELECT PERCENTILE_CONT(x, 0.5 RESPECT NULLS) OVER ()
+SELECT PERCENTILE_CONT(x, 0.5 IGNORE NULLS) OVER ()
diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql
index 3013bba..f077647 100644
--- a/tests/fixtures/optimizer/qualify_columns.sql
+++ b/tests/fixtures/optimizer/qualify_columns.sql
@@ -4,6 +4,9 @@
SELECT a FROM x;
SELECT x.a AS a FROM x AS x;
+SELECT "a" FROM x;
+SELECT x."a" AS "a" FROM x AS x;
+
# execute: false
SELECT a FROM zz GROUP BY a ORDER BY a;
SELECT zz.a AS a FROM zz AS zz GROUP BY zz.a ORDER BY a;
@@ -212,6 +215,10 @@ SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT x.b AS b FROM y AS x);
SELECT a FROM x AS i WHERE b IN (SELECT b FROM y AS j WHERE j.b IN (SELECT c FROM y AS k WHERE k.b = j.b));
SELECT i.a AS a FROM x AS i WHERE i.b IN (SELECT j.b AS b FROM y AS j WHERE j.b IN (SELECT k.c AS c FROM y AS k WHERE k.b = j.b));
+# execute: false
+SELECT (SELECT n.a FROM n WHERE n.id = m.id) FROM m AS m;
+SELECT (SELECT n.a AS a FROM n AS n WHERE n.id = m.id) AS _col_0 FROM m AS m;
+
--------------------------------------
-- Expand *
--------------------------------------
diff --git a/tests/fixtures/optimizer/qualify_tables.sql b/tests/fixtures/optimizer/qualify_tables.sql
index 2cea85d..0ad155a 100644
--- a/tests/fixtures/optimizer/qualify_tables.sql
+++ b/tests/fixtures/optimizer/qualify_tables.sql
@@ -15,3 +15,26 @@ WITH a AS (SELECT 1 FROM c.db.z AS z) SELECT 1 FROM a;
SELECT (SELECT y.c FROM y AS y) FROM x;
SELECT (SELECT y.c FROM c.db.y AS y) FROM c.db.x AS x;
+
+-------------------------
+-- Expand join constructs
+-------------------------
+
+-- This is valid in Trino, so we treat the (tbl AS tbl) as a "join construct" per postgres' terminology.
+SELECT * FROM (tbl AS tbl) AS _q_0;
+SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0;
+
+SELECT * FROM ((tbl AS tbl)) AS _q_0;
+SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0;
+
+SELECT * FROM (((tbl AS tbl))) AS _q_0;
+SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0;
+
+SELECT * FROM (tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2 JOIN tbl3 AS tbl3 ON id1 = id3) AS _q_0;
+SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN c.db.tbl2 AS tbl2 ON id1 = id2 JOIN c.db.tbl3 AS tbl3 ON id1 = id3) AS _q_0;
+
+SELECT * FROM ((tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2 JOIN tbl3 AS tbl3 ON id1 = id3)) AS _q_0;
+SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN c.db.tbl2 AS tbl2 ON id1 = id2 JOIN c.db.tbl3 AS tbl3 ON id1 = id3) AS _q_0;
+
+SELECT * FROM (tbl1 AS tbl1 JOIN (tbl2 AS tbl2 JOIN tbl3 AS tbl3 ON id2 = id3) AS _q_0 ON id1 = id3) AS _q_1;
+SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN (SELECT * FROM c.db.tbl2 AS tbl2 JOIN c.db.tbl3 AS tbl3 ON id2 = id3) AS _q_0 ON id1 = id3) AS _q_1;
diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
index 9168508..9908756 100644
--- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
+++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
@@ -6386,6 +6386,14 @@ WITH "tmp1" AS (
OR "item"."i_class" IN ('personal', 'portable', 'reference', 'self-help')
)
AND (
+ "item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
+ OR "item"."i_category" IN ('Women', 'Music', 'Men')
+ )
+ AND (
+ "item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
+ OR "item"."i_class" IN ('accessories', 'classical', 'fragrances', 'pants')
+ )
+ AND (
"item"."i_category" IN ('Books', 'Children', 'Electronics')
OR "item"."i_category" IN ('Women', 'Music', 'Men')
)
@@ -7590,6 +7598,14 @@ WITH "tmp1" AS (
OR "item"."i_class" IN ('personal', 'portable', 'reference', 'self-help')
)
AND (
+ "item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
+ OR "item"."i_category" IN ('Women', 'Music', 'Men')
+ )
+ AND (
+ "item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
+ OR "item"."i_class" IN ('accessories', 'classical', 'fragrances', 'pants')
+ )
+ AND (
"item"."i_category" IN ('Books', 'Children', 'Electronics')
OR "item"."i_category" IN ('Women', 'Music', 'Men')
)
diff --git a/tests/test_build.py b/tests/test_build.py
index c4b97ce..509b857 100644
--- a/tests/test_build.py
+++ b/tests/test_build.py
@@ -19,6 +19,11 @@ from sqlglot import (
class TestBuild(unittest.TestCase):
def test_build(self):
x = condition("x")
+ x_plus_one = x + 1
+
+ # Make sure we're not mutating x by changing its parent to be x_plus_one
+ self.assertIsNone(x.parent)
+ self.assertNotEqual(id(x_plus_one.this), id(x))
for expression, sql, *dialect in [
(lambda: x + 1, "x + 1"),
@@ -51,6 +56,7 @@ class TestBuild(unittest.TestCase):
(lambda: x.neq(1), "x <> 1"),
(lambda: x.isin(1, "2"), "x IN (1, '2')"),
(lambda: x.isin(query="select 1"), "x IN (SELECT 1)"),
+ (lambda: x.between(1, 2), "x BETWEEN 1 AND 2"),
(lambda: 1 + x + 2 + 3, "1 + x + 2 + 3"),
(lambda: 1 + x * 2 + 3, "1 + (x * 2) + 3"),
(lambda: x * 1 * 2 + 3, "(x * 1 * 2) + 3"),
@@ -137,10 +143,14 @@ class TestBuild(unittest.TestCase):
"SELECT x, y, z, a FROM tbl GROUP BY x, y, z, a",
),
(
- lambda: select("x").distinct(True).from_("tbl"),
+ lambda: select("x").distinct("a", "b").from_("tbl"),
+ "SELECT DISTINCT ON (a, b) x FROM tbl",
+ ),
+ (
+ lambda: select("x").distinct(distinct=True).from_("tbl"),
"SELECT DISTINCT x FROM tbl",
),
- (lambda: select("x").distinct(False).from_("tbl"), "SELECT x FROM tbl"),
+ (lambda: select("x").distinct(distinct=False).from_("tbl"), "SELECT x FROM tbl"),
(
lambda: select("x").lateral("OUTER explode(y) tbl2 AS z").from_("tbl"),
"SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z",
@@ -583,6 +593,11 @@ class TestBuild(unittest.TestCase):
"DELETE FROM tbl WHERE x = 1 RETURNING *",
"postgres",
),
+ (
+ lambda: exp.convert((exp.column("x"), exp.column("y"))).isin((1, 2), (3, 4)),
+ "(x, y) IN ((1, 2), (3, 4))",
+ "postgres",
+ ),
]:
with self.subTest(sql):
self.assertEqual(expression().sql(dialect[0] if dialect else None), sql)
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index eb0cf56..e7588b5 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -297,6 +297,9 @@ class TestExpressions(unittest.TestCase):
expression = parse_one("SELECT a, b FROM x")
self.assertEqual([s.sql() for s in expression.selects], ["a", "b"])
+ expression = parse_one("(SELECT a, b FROM x)")
+ self.assertEqual([s.sql() for s in expression.selects], ["a", "b"])
+
def test_alias_column_names(self):
expression = parse_one("SELECT * FROM (SELECT * FROM x) AS y")
subquery = expression.find(exp.Subquery)
@@ -761,7 +764,7 @@ FROM foo""",
"t",
{"a": exp.DataType.build("TEXT"), "b": exp.DataType.build("TEXT")},
).sql(),
- "(VALUES (CAST(1 AS TEXT), CAST(2 AS TEXT)), (3, 4)) AS t(a, b)",
+ "(VALUES (1, 2), (3, 4)) AS t(a, b)",
)
with self.assertRaises(ValueError):
exp.values([(1, 2), (3, 4)], columns=["a"])
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index d077570..423cb84 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -47,6 +47,7 @@ class TestOptimizer(unittest.TestCase):
@classmethod
def setUpClass(cls):
+ sqlglot.schema = MappingSchema()
cls.conn = duckdb.connect()
cls.conn.execute(
"""
@@ -221,6 +222,12 @@ class TestOptimizer(unittest.TestCase):
self.check_file("pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates)
def test_expand_laterals(self):
+ # check order of lateral expansion with no schema
+ self.assertEqual(
+ optimizer.optimize("SELECT a + 1 AS d, d + 1 AS e FROM x " "").sql(),
+ 'SELECT "x"."a" + 1 AS "d", "x"."a" + 2 AS "e" FROM "x" AS "x"',
+ )
+
self.check_file(
"expand_laterals",
optimizer.expand_laterals.expand_laterals,
@@ -612,6 +619,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
expression = annotate_types(parse_one("CONCAT('A', 'B')"))
self.assertEqual(expression.type.this, exp.DataType.Type.VARCHAR)
+ def test_root_subquery_annotation(self):
+ expression = annotate_types(parse_one("(SELECT 1, 2 FROM x) LIMIT 0"))
+ self.assertIsInstance(expression, exp.Subquery)
+ self.assertEqual(exp.DataType.Type.INT, expression.selects[0].type.this)
+ self.assertEqual(exp.DataType.Type.INT, expression.selects[1].type.this)
+
def test_recursive_cte(self):
query = parse_one(
"""
diff --git a/tests/test_tokens.py b/tests/test_tokens.py
index 987c60b..f70d70e 100644
--- a/tests/test_tokens.py
+++ b/tests/test_tokens.py
@@ -102,3 +102,14 @@ x"""
(TokenType.SEMICOLON, ";"),
],
)
+
+ tokens = tokenizer.tokenize("""'{{ var('x') }}'""")
+ tokens = [(token.token_type, token.text) for token in tokens]
+ self.assertEqual(
+ tokens,
+ [
+ (TokenType.STRING, "{{ var("),
+ (TokenType.VAR, "x"),
+ (TokenType.STRING, ") }}"),
+ ],
+ )
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index d68f6f8..ad8ec72 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -263,6 +263,18 @@ FROM v""",
"(/* 1 */ 1 ) /* 2 */",
"(1) /* 1 */ /* 2 */",
)
+ self.validate(
+ "select * from t where not a in (23) /*test*/ and b in (14)",
+ "SELECT * FROM t WHERE NOT a IN (23) /* test */ AND b IN (14)",
+ )
+ self.validate(
+ "select * from t where a in (23) /*test*/ and b in (14)",
+ "SELECT * FROM t WHERE a IN (23) /* test */ AND b IN (14)",
+ )
+ self.validate(
+ "select * from t where ((condition = 1)/*test*/)",
+ "SELECT * FROM t WHERE ((condition = 1) /* test */)",
+ )
def test_types(self):
self.validate("INT 1", "CAST(1 AS INT)")
@@ -324,9 +336,6 @@ FROM v""",
)
self.validate("SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo")
- def test_ignore_nulls(self):
- self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)")
-
def test_with(self):
self.validate(
"WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *",
@@ -482,7 +491,7 @@ FROM v""",
self.validate("UNIX_TO_STR(123, 'y')", "FROM_UNIXTIME(123, 'y')", write="spark")
self.validate(
"UNIX_TO_TIME(123)",
- "FROM_UNIXTIME(123)",
+ "CAST(FROM_UNIXTIME(123) AS TIMESTAMP)",
write="spark",
)
self.validate(