diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dataframe/unit/test_functions.py | 3 | ||||
-rw-r--r-- | tests/dataframe/unit/test_session.py | 3 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 11 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 6 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 7 | ||||
-rw-r--r-- | tests/dialects/test_teradata.py | 2 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 4 | ||||
-rw-r--r-- | tests/test_diff.py | 29 | ||||
-rw-r--r-- | tests/test_expressions.py | 38 | ||||
-rw-r--r-- | tests/test_serde.py | 6 | ||||
-rw-r--r-- | tests/test_tokens.py | 13 | ||||
-rw-r--r-- | tests/test_transpile.py | 5 |
12 files changed, 119 insertions, 8 deletions
diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index f155065..d9a32c4 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -2,8 +2,7 @@ import datetime import inspect import unittest -from sqlglot import expressions as exp -from sqlglot import parse_one +from sqlglot import expressions as exp, parse_one from sqlglot.dataframe.sql import functions as SF from sqlglot.errors import ErrorLevel diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py index f5b79fd..7da0833 100644 --- a/tests/dataframe/unit/test_session.py +++ b/tests/dataframe/unit/test_session.py @@ -1,8 +1,7 @@ from unittest import mock import sqlglot -from sqlglot.dataframe.sql import functions as F -from sqlglot.dataframe.sql import types +from sqlglot.dataframe.sql import functions as F, types from sqlglot.dataframe.sql.session import SparkSession from sqlglot.schema import MappingSchema from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 685dea4..3186390 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -285,6 +285,10 @@ class TestDialect(Validator): read={"oracle": "CAST(a AS NUMBER)"}, write={"oracle": "CAST(a AS NUMBER)"}, ) + self.validate_all( + "CAST('127.0.0.1/32' AS INET)", + read={"postgres": "INET '127.0.0.1/32'"}, + ) def test_if_null(self): self.validate_all( @@ -509,7 +513,7 @@ class TestDialect(Validator): "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", }, write={ - "bigquery": "DATE_ADD(x, INTERVAL 1 'day')", + "bigquery": "DATE_ADD(x, INTERVAL 1 DAY)", "drill": "DATE_ADD(x, INTERVAL 1 DAY)", "duckdb": "x + INTERVAL 1 day", "hive": "DATE_ADD(x, 1)", @@ -526,7 +530,7 @@ class TestDialect(Validator): self.validate_all( "DATE_ADD(x, 1)", write={ - "bigquery": "DATE_ADD(x, INTERVAL 1 'day')", + "bigquery": "DATE_ADD(x, INTERVAL 1 DAY)", "drill": "DATE_ADD(x, INTERVAL 1 DAY)", "duckdb": "x + INTERVAL 1 DAY", "hive": "DATE_ADD(x, 1)", @@ -540,6 +544,7 @@ class TestDialect(Validator): "DATE_TRUNC('day', x)", write={ "mysql": "DATE(x)", + "snowflake": "DATE_TRUNC('day', x)", }, ) self.validate_all( @@ -576,6 +581,7 @@ class TestDialect(Validator): "DATE_TRUNC('year', x)", read={ "bigquery": "DATE_TRUNC(x, year)", + "snowflake": "DATE_TRUNC(year, x)", "starrocks": "DATE_TRUNC('year', x)", "spark": "TRUNC(x, 'year')", }, @@ -583,6 +589,7 @@ class TestDialect(Validator): "bigquery": "DATE_TRUNC(x, year)", "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", "postgres": "DATE_TRUNC('year', x)", + "snowflake": "DATE_TRUNC('year', x)", "starrocks": "DATE_TRUNC('year', x)", "spark": "TRUNC(x, 'year')", }, diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 9e22527..a934c78 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -397,6 +397,12 @@ class TestSnowflake(Validator): }, ) + self.validate_all( + "CREATE TABLE a (b INT)", + read={"teradata": "CREATE MULTISET TABLE a (b INT)"}, + write={"snowflake": "CREATE TABLE a (b INT)"}, + ) + def test_user_defined_functions(self): self.validate_all( "CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS $$ SELECT 1 $$", diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index be74a27..9328eaa 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -214,6 +214,13 @@ TBLPROPERTIES ( self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')") self.validate_all( + "SELECT DATE_ADD(my_date_column, 1)", + write={ + "spark": "SELECT DATE_ADD(my_date_column, 1)", + "bigquery": "SELECT DATE_ADD(my_date_column, INTERVAL 1 DAY)", + }, + ) + self.validate_all( "AGGREGATE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)", write={ "trino": "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)", diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index ab87eef..dd251ab 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -35,6 +35,8 @@ class TestTeradata(Validator): write={"teradata": "SELECT a FROM b"}, ) + self.validate_identity("CREATE VOLATILE TABLE a (b INT)") + def test_insert(self): self.validate_all( "INS INTO x SELECT * FROM y", write={"teradata": "INSERT INTO x SELECT * FROM y"} diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 7c4ec8e..5e2260c 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -305,6 +305,7 @@ SELECT a FROM test TABLESAMPLE(100 ROWS) SELECT a FROM test TABLESAMPLE BERNOULLI (50) SELECT a FROM test TABLESAMPLE SYSTEM (75) SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) +SELECT 1 FROM a.b.table1 AS t UNPIVOT((c3) FOR c4 IN (a, b)) SELECT a FROM test PIVOT(SOMEAGG(x, y, z) FOR q IN (1)) SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) PIVOT(MAX(b) FOR c IN ('d')) SELECT a FROM (SELECT a, b FROM test) PIVOT(SUM(x) FOR y IN ('z', 'q')) @@ -557,10 +558,11 @@ CREATE TABLE a, BEFORE JOURNAL, AFTER JOURNAL, FREESPACE=1, DEFAULT DATABLOCKSIZ CREATE TABLE a, DUAL JOURNAL, DUAL AFTER JOURNAL, MERGEBLOCKRATIO=1 PERCENT, DATABLOCKSIZE=10 KILOBYTES (a INT) CREATE TABLE a, DUAL BEFORE JOURNAL, LOCAL AFTER JOURNAL, MAXIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=AUTOTEMP(c1 INT) (a INT) CREATE SET GLOBAL TEMPORARY TABLE a, NO BEFORE JOURNAL, NO AFTER JOURNAL, MINIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=NEVER (a INT) -CREATE MULTISET VOLATILE TABLE a, NOT LOCAL AFTER JOURNAL, FREESPACE=1 PERCENT, DATABLOCKSIZE=10 BYTES, WITH NO CONCURRENT ISOLATED LOADING FOR ALL (a INT) +CREATE MULTISET TABLE a, NOT LOCAL AFTER JOURNAL, FREESPACE=1 PERCENT, DATABLOCKSIZE=10 BYTES, WITH NO CONCURRENT ISOLATED LOADING FOR ALL (a INT) CREATE ALGORITHM=UNDEFINED DEFINER=foo@% SQL SECURITY DEFINER VIEW a AS (SELECT a FROM b) CREATE TEMPORARY TABLE x AS SELECT a FROM d CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d +CREATE TABLE a (b INT) ON COMMIT PRESERVE ROWS CREATE VIEW x AS SELECT a FROM b CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d diff --git a/tests/test_diff.py b/tests/test_diff.py index cbd53b3..372af70 100644 --- a/tests/test_diff.py +++ b/tests/test_diff.py @@ -1,6 +1,6 @@ import unittest -from sqlglot import parse_one +from sqlglot import exp, parse_one from sqlglot.diff import Insert, Keep, Move, Remove, Update, diff from sqlglot.expressions import Join, to_identifier @@ -128,6 +128,33 @@ class TestDiff(unittest.TestCase): ], ) + def test_pre_matchings(self): + expr_src = parse_one("SELECT 1") + expr_tgt = parse_one("SELECT 1, 2, 3, 4") + + self._validate_delta_only( + diff(expr_src, expr_tgt), + [ + Remove(expr_src), + Insert(expr_tgt), + Insert(exp.Literal.number(2)), + Insert(exp.Literal.number(3)), + Insert(exp.Literal.number(4)), + ], + ) + + self._validate_delta_only( + diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]), + [ + Insert(exp.Literal.number(2)), + Insert(exp.Literal.number(3)), + Insert(exp.Literal.number(4)), + ], + ) + + with self.assertRaises(ValueError): + diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt), (expr_src, expr_tgt)]) + def _validate_delta_only(self, actual_diff, expected_delta): actual_delta = _delta_only(actual_diff) self.assertEqual(set(actual_delta), set(expected_delta)) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 8b74fe1..caa419e 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -91,6 +91,11 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(column.parent_select, exp.Select) self.assertIsNone(column.find_ancestor(exp.Join)) + def test_root(self): + ast = parse_one("select * from (select a from x)") + self.assertIs(ast, ast.root()) + self.assertIs(ast, ast.find(exp.Column).root()) + def test_alias_or_name(self): expression = parse_one( "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" @@ -767,3 +772,36 @@ FROM foo""", exp.rename_table("t1", "t2").sql(), "ALTER TABLE t1 RENAME TO t2", ) + + def test_is_star(self): + assert parse_one("*").is_star + assert parse_one("foo.*").is_star + assert parse_one("SELECT * FROM foo").is_star + assert parse_one("(SELECT * FROM foo)").is_star + assert parse_one("SELECT *, 1 FROM foo").is_star + assert parse_one("SELECT foo.* FROM foo").is_star + assert parse_one("SELECT * EXCEPT (a, b) FROM foo").is_star + assert parse_one("SELECT foo.* EXCEPT (foo.a, foo.b) FROM foo").is_star + assert parse_one("SELECT * REPLACE (a AS b, b AS C)").is_star + assert parse_one("SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C)").is_star + assert parse_one("SELECT * INTO newevent FROM event").is_star + assert parse_one("SELECT * FROM foo UNION SELECT * FROM bar").is_star + assert parse_one("SELECT * FROM bla UNION SELECT 1 AS x").is_star + assert parse_one("SELECT 1 AS x UNION SELECT * FROM bla").is_star + assert parse_one("SELECT 1 AS x UNION SELECT 1 AS x UNION SELECT * FROM foo").is_star + + def test_set_metadata(self): + ast = parse_one("SELECT foo.col FROM foo") + + self.assertIsNone(ast._meta) + + # calling ast.meta would lazily instantiate self._meta + self.assertEqual(ast.meta, {}) + self.assertEqual(ast._meta, {}) + + ast.meta["some_meta_key"] = "some_meta_value" + self.assertEqual(ast.meta.get("some_meta_key"), "some_meta_value") + self.assertEqual(ast.meta.get("some_other_meta_key"), None) + + ast.meta["some_other_meta_key"] = "some_other_meta_value" + self.assertEqual(ast.meta.get("some_other_meta_key"), "some_other_meta_value") diff --git a/tests/test_serde.py b/tests/test_serde.py index 603a155..6b5c989 100644 --- a/tests/test_serde.py +++ b/tests/test_serde.py @@ -31,3 +31,9 @@ class TestSerDe(unittest.TestCase): after = self.dump_load(before) self.assertEqual(before.type, after.type) self.assertEqual(before.this.type, after.this.type) + + def test_meta(self): + before = parse_one("SELECT * FROM X") + before.meta["x"] = 1 + after = self.dump_load(before) + self.assertEqual(before.meta, after.meta) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index d30c445..0888555 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -18,6 +18,18 @@ class TestTokens(unittest.TestCase): for sql, comment in sql_comment: self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment) + def test_token_line(self): + tokens = Tokenizer().tokenize( + """SELECT /* + line break + */ + 'x + y', + x""" + ) + + self.assertEqual(tokens[-1].line, 6) + def test_jinja(self): tokenizer = Tokenizer() @@ -26,6 +38,7 @@ class TestTokens(unittest.TestCase): SELECT {{ x }}, {{- x -}}, + {# it's a comment #} {% for x in y -%} a {{+ b }} {% endfor %}; diff --git a/tests/test_transpile.py b/tests/test_transpile.py index c0d518d..0463aed 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -28,6 +28,11 @@ class TestTranspile(unittest.TestCase): self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime") self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row") + self.assertEqual( + transpile("SELECT 1 FROM a.b.table1 t UNPIVOT((c3) FOR c4 IN (a, b))")[0], + "SELECT 1 FROM a.b.table1 AS t UNPIVOT((c3) FOR c4 IN (a, b))", + ) + for key in ("union", "over", "from", "join"): with self.subTest(f"alias {key}"): self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}") |