diff options
Diffstat (limited to 'tests/dialects/test_postgres.py')
-rw-r--r-- | tests/dialects/test_postgres.py | 93 |
1 files changed, 82 insertions, 11 deletions
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 15dbfd0..e0934d7 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -8,9 +8,7 @@ class TestPostgres(Validator): def test_ddl(self): self.validate_all( "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)", - write={ - "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)" - }, + write={"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"}, ) self.validate_all( "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)", @@ -42,11 +40,17 @@ class TestPostgres(Validator): " CONSTRAINT valid_discount CHECK (price > discounted_price))" }, ) + self.validate_all( + "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)", + write={"postgres": "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)"}, + ) + self.validate_all( + "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)", + write={"postgres": "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)"}, + ) with self.assertRaises(ParseError): - transpile( - "CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres" - ) + transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres") with self.assertRaises(ParseError): transpile( "CREATE TABLE products (price DECIMAL, CHECK price > 1)", @@ -54,11 +58,16 @@ class TestPostgres(Validator): ) def test_postgres(self): - self.validate_all( - "CREATE TABLE x (a INT SERIAL)", - read={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"}, - write={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"}, - ) + self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END") + self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END") + self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END") + self.validate_identity('SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')') + self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')") + self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')") + self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))") + self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')") + self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)") + self.validate_all( "CREATE TABLE x (a UUID, b BYTEA)", write={ @@ -91,3 +100,65 @@ class TestPostgres(Validator): "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", }, ) + self.validate_all( + "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END", + write={ + "hive": "SELECT CASE WHEN SUBSTRING('abcdefg', 1, 2) IN ('ab') THEN 1 ELSE 0 END", + "spark": "SELECT CASE WHEN SUBSTRING('abcdefg', 1, 2) IN ('ab') THEN 1 ELSE 0 END", + }, + ) + self.validate_all( + "SELECT * FROM x WHERE SUBSTRING(col1 FROM 3 + LENGTH(col1) - 10 FOR 10) IN (col2)", + write={ + "hive": "SELECT * FROM x WHERE SUBSTRING(col1, 3 + LENGTH(col1) - 10, 10) IN (col2)", + "spark": "SELECT * FROM x WHERE SUBSTRING(col1, 3 + LENGTH(col1) - 10, 10) IN (col2)", + }, + ) + self.validate_all( + "SELECT SUBSTRING(CAST(2022 AS CHAR(4)) || LPAD(CAST(3 AS CHAR(2)), 2, '0') FROM 3 FOR 4)", + read={ + "postgres": "SELECT SUBSTRING(2022::CHAR(4) || LPAD(3::CHAR(2), 2, '0') FROM 3 FOR 4)", + }, + ) + self.validate_all( + "SELECT TRIM(BOTH ' XXX ')", + write={ + "mysql": "SELECT TRIM(' XXX ')", + "postgres": "SELECT TRIM(' XXX ')", + "hive": "SELECT TRIM(' XXX ')", + }, + ) + self.validate_all( + "TRIM(LEADING FROM ' XXX ')", + write={ + "mysql": "LTRIM(' XXX ')", + "postgres": "LTRIM(' XXX ')", + "hive": "LTRIM(' XXX ')", + "presto": "LTRIM(' XXX ')", + }, + ) + self.validate_all( + "TRIM(TRAILING FROM ' XXX ')", + write={ + "mysql": "RTRIM(' XXX ')", + "postgres": "RTRIM(' XXX ')", + "hive": "RTRIM(' XXX ')", + "presto": "RTRIM(' XXX ')", + }, + ) + self.validate_all( + "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss", + read={"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"}, + ) + self.validate_all( + "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL", + read={ + "postgres": "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL", + }, + ) + self.validate_all( + "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id", + read={ + "postgres": "SELECT p1.id, p2.id, v1, v2 FROM polygons p1, polygons p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id != p2.id", + }, + ) |