summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/dataframe/unit/test_functions.py3
-rw-r--r--tests/dataframe/unit/test_session.py3
-rw-r--r--tests/dialects/test_dialect.py11
-rw-r--r--tests/dialects/test_snowflake.py6
-rw-r--r--tests/dialects/test_spark.py7
-rw-r--r--tests/dialects/test_teradata.py2
-rw-r--r--tests/fixtures/identity.sql4
-rw-r--r--tests/test_diff.py29
-rw-r--r--tests/test_expressions.py38
-rw-r--r--tests/test_serde.py6
-rw-r--r--tests/test_tokens.py13
-rw-r--r--tests/test_transpile.py5
12 files changed, 119 insertions, 8 deletions
diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py
index f155065..d9a32c4 100644
--- a/tests/dataframe/unit/test_functions.py
+++ b/tests/dataframe/unit/test_functions.py
@@ -2,8 +2,7 @@ import datetime
import inspect
import unittest
-from sqlglot import expressions as exp
-from sqlglot import parse_one
+from sqlglot import expressions as exp, parse_one
from sqlglot.dataframe.sql import functions as SF
from sqlglot.errors import ErrorLevel
diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py
index f5b79fd..7da0833 100644
--- a/tests/dataframe/unit/test_session.py
+++ b/tests/dataframe/unit/test_session.py
@@ -1,8 +1,7 @@
from unittest import mock
import sqlglot
-from sqlglot.dataframe.sql import functions as F
-from sqlglot.dataframe.sql import types
+from sqlglot.dataframe.sql import functions as F, types
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.schema import MappingSchema
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 685dea4..3186390 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -285,6 +285,10 @@ class TestDialect(Validator):
read={"oracle": "CAST(a AS NUMBER)"},
write={"oracle": "CAST(a AS NUMBER)"},
)
+ self.validate_all(
+ "CAST('127.0.0.1/32' AS INET)",
+ read={"postgres": "INET '127.0.0.1/32'"},
+ )
def test_if_null(self):
self.validate_all(
@@ -509,7 +513,7 @@ class TestDialect(Validator):
"starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
},
write={
- "bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
+ "bigquery": "DATE_ADD(x, INTERVAL 1 DAY)",
"drill": "DATE_ADD(x, INTERVAL 1 DAY)",
"duckdb": "x + INTERVAL 1 day",
"hive": "DATE_ADD(x, 1)",
@@ -526,7 +530,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_ADD(x, 1)",
write={
- "bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
+ "bigquery": "DATE_ADD(x, INTERVAL 1 DAY)",
"drill": "DATE_ADD(x, INTERVAL 1 DAY)",
"duckdb": "x + INTERVAL 1 DAY",
"hive": "DATE_ADD(x, 1)",
@@ -540,6 +544,7 @@ class TestDialect(Validator):
"DATE_TRUNC('day', x)",
write={
"mysql": "DATE(x)",
+ "snowflake": "DATE_TRUNC('day', x)",
},
)
self.validate_all(
@@ -576,6 +581,7 @@ class TestDialect(Validator):
"DATE_TRUNC('year', x)",
read={
"bigquery": "DATE_TRUNC(x, year)",
+ "snowflake": "DATE_TRUNC(year, x)",
"starrocks": "DATE_TRUNC('year', x)",
"spark": "TRUNC(x, 'year')",
},
@@ -583,6 +589,7 @@ class TestDialect(Validator):
"bigquery": "DATE_TRUNC(x, year)",
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
"postgres": "DATE_TRUNC('year', x)",
+ "snowflake": "DATE_TRUNC('year', x)",
"starrocks": "DATE_TRUNC('year', x)",
"spark": "TRUNC(x, 'year')",
},
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 9e22527..a934c78 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -397,6 +397,12 @@ class TestSnowflake(Validator):
},
)
+ self.validate_all(
+ "CREATE TABLE a (b INT)",
+ read={"teradata": "CREATE MULTISET TABLE a (b INT)"},
+ write={"snowflake": "CREATE TABLE a (b INT)"},
+ )
+
def test_user_defined_functions(self):
self.validate_all(
"CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS $$ SELECT 1 $$",
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index be74a27..9328eaa 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -214,6 +214,13 @@ TBLPROPERTIES (
self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')")
self.validate_all(
+ "SELECT DATE_ADD(my_date_column, 1)",
+ write={
+ "spark": "SELECT DATE_ADD(my_date_column, 1)",
+ "bigquery": "SELECT DATE_ADD(my_date_column, INTERVAL 1 DAY)",
+ },
+ )
+ self.validate_all(
"AGGREGATE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
write={
"trino": "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py
index ab87eef..dd251ab 100644
--- a/tests/dialects/test_teradata.py
+++ b/tests/dialects/test_teradata.py
@@ -35,6 +35,8 @@ class TestTeradata(Validator):
write={"teradata": "SELECT a FROM b"},
)
+ self.validate_identity("CREATE VOLATILE TABLE a (b INT)")
+
def test_insert(self):
self.validate_all(
"INS INTO x SELECT * FROM y", write={"teradata": "INSERT INTO x SELECT * FROM y"}
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index 7c4ec8e..5e2260c 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -305,6 +305,7 @@ SELECT a FROM test TABLESAMPLE(100 ROWS)
SELECT a FROM test TABLESAMPLE BERNOULLI (50)
SELECT a FROM test TABLESAMPLE SYSTEM (75)
SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q'))
+SELECT 1 FROM a.b.table1 AS t UNPIVOT((c3) FOR c4 IN (a, b))
SELECT a FROM test PIVOT(SOMEAGG(x, y, z) FOR q IN (1))
SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) PIVOT(MAX(b) FOR c IN ('d'))
SELECT a FROM (SELECT a, b FROM test) PIVOT(SUM(x) FOR y IN ('z', 'q'))
@@ -557,10 +558,11 @@ CREATE TABLE a, BEFORE JOURNAL, AFTER JOURNAL, FREESPACE=1, DEFAULT DATABLOCKSIZ
CREATE TABLE a, DUAL JOURNAL, DUAL AFTER JOURNAL, MERGEBLOCKRATIO=1 PERCENT, DATABLOCKSIZE=10 KILOBYTES (a INT)
CREATE TABLE a, DUAL BEFORE JOURNAL, LOCAL AFTER JOURNAL, MAXIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=AUTOTEMP(c1 INT) (a INT)
CREATE SET GLOBAL TEMPORARY TABLE a, NO BEFORE JOURNAL, NO AFTER JOURNAL, MINIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=NEVER (a INT)
-CREATE MULTISET VOLATILE TABLE a, NOT LOCAL AFTER JOURNAL, FREESPACE=1 PERCENT, DATABLOCKSIZE=10 BYTES, WITH NO CONCURRENT ISOLATED LOADING FOR ALL (a INT)
+CREATE MULTISET TABLE a, NOT LOCAL AFTER JOURNAL, FREESPACE=1 PERCENT, DATABLOCKSIZE=10 BYTES, WITH NO CONCURRENT ISOLATED LOADING FOR ALL (a INT)
CREATE ALGORITHM=UNDEFINED DEFINER=foo@% SQL SECURITY DEFINER VIEW a AS (SELECT a FROM b)
CREATE TEMPORARY TABLE x AS SELECT a FROM d
CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d
+CREATE TABLE a (b INT) ON COMMIT PRESERVE ROWS
CREATE VIEW x AS SELECT a FROM b
CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b
CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d
diff --git a/tests/test_diff.py b/tests/test_diff.py
index cbd53b3..372af70 100644
--- a/tests/test_diff.py
+++ b/tests/test_diff.py
@@ -1,6 +1,6 @@
import unittest
-from sqlglot import parse_one
+from sqlglot import exp, parse_one
from sqlglot.diff import Insert, Keep, Move, Remove, Update, diff
from sqlglot.expressions import Join, to_identifier
@@ -128,6 +128,33 @@ class TestDiff(unittest.TestCase):
],
)
+ def test_pre_matchings(self):
+ expr_src = parse_one("SELECT 1")
+ expr_tgt = parse_one("SELECT 1, 2, 3, 4")
+
+ self._validate_delta_only(
+ diff(expr_src, expr_tgt),
+ [
+ Remove(expr_src),
+ Insert(expr_tgt),
+ Insert(exp.Literal.number(2)),
+ Insert(exp.Literal.number(3)),
+ Insert(exp.Literal.number(4)),
+ ],
+ )
+
+ self._validate_delta_only(
+ diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]),
+ [
+ Insert(exp.Literal.number(2)),
+ Insert(exp.Literal.number(3)),
+ Insert(exp.Literal.number(4)),
+ ],
+ )
+
+ with self.assertRaises(ValueError):
+ diff(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt), (expr_src, expr_tgt)])
+
def _validate_delta_only(self, actual_diff, expected_delta):
actual_delta = _delta_only(actual_diff)
self.assertEqual(set(actual_delta), set(expected_delta))
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index 8b74fe1..caa419e 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -91,6 +91,11 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(column.parent_select, exp.Select)
self.assertIsNone(column.find_ancestor(exp.Join))
+ def test_root(self):
+ ast = parse_one("select * from (select a from x)")
+ self.assertIs(ast, ast.root())
+ self.assertIs(ast, ast.find(exp.Column).root())
+
def test_alias_or_name(self):
expression = parse_one(
"SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
@@ -767,3 +772,36 @@ FROM foo""",
exp.rename_table("t1", "t2").sql(),
"ALTER TABLE t1 RENAME TO t2",
)
+
+ def test_is_star(self):
+ assert parse_one("*").is_star
+ assert parse_one("foo.*").is_star
+ assert parse_one("SELECT * FROM foo").is_star
+ assert parse_one("(SELECT * FROM foo)").is_star
+ assert parse_one("SELECT *, 1 FROM foo").is_star
+ assert parse_one("SELECT foo.* FROM foo").is_star
+ assert parse_one("SELECT * EXCEPT (a, b) FROM foo").is_star
+ assert parse_one("SELECT foo.* EXCEPT (foo.a, foo.b) FROM foo").is_star
+ assert parse_one("SELECT * REPLACE (a AS b, b AS C)").is_star
+ assert parse_one("SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C)").is_star
+ assert parse_one("SELECT * INTO newevent FROM event").is_star
+ assert parse_one("SELECT * FROM foo UNION SELECT * FROM bar").is_star
+ assert parse_one("SELECT * FROM bla UNION SELECT 1 AS x").is_star
+ assert parse_one("SELECT 1 AS x UNION SELECT * FROM bla").is_star
+ assert parse_one("SELECT 1 AS x UNION SELECT 1 AS x UNION SELECT * FROM foo").is_star
+
+ def test_set_metadata(self):
+ ast = parse_one("SELECT foo.col FROM foo")
+
+ self.assertIsNone(ast._meta)
+
+ # calling ast.meta would lazily instantiate self._meta
+ self.assertEqual(ast.meta, {})
+ self.assertEqual(ast._meta, {})
+
+ ast.meta["some_meta_key"] = "some_meta_value"
+ self.assertEqual(ast.meta.get("some_meta_key"), "some_meta_value")
+ self.assertEqual(ast.meta.get("some_other_meta_key"), None)
+
+ ast.meta["some_other_meta_key"] = "some_other_meta_value"
+ self.assertEqual(ast.meta.get("some_other_meta_key"), "some_other_meta_value")
diff --git a/tests/test_serde.py b/tests/test_serde.py
index 603a155..6b5c989 100644
--- a/tests/test_serde.py
+++ b/tests/test_serde.py
@@ -31,3 +31,9 @@ class TestSerDe(unittest.TestCase):
after = self.dump_load(before)
self.assertEqual(before.type, after.type)
self.assertEqual(before.this.type, after.this.type)
+
+ def test_meta(self):
+ before = parse_one("SELECT * FROM X")
+ before.meta["x"] = 1
+ after = self.dump_load(before)
+ self.assertEqual(before.meta, after.meta)
diff --git a/tests/test_tokens.py b/tests/test_tokens.py
index d30c445..0888555 100644
--- a/tests/test_tokens.py
+++ b/tests/test_tokens.py
@@ -18,6 +18,18 @@ class TestTokens(unittest.TestCase):
for sql, comment in sql_comment:
self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment)
+ def test_token_line(self):
+ tokens = Tokenizer().tokenize(
+ """SELECT /*
+ line break
+ */
+ 'x
+ y',
+ x"""
+ )
+
+ self.assertEqual(tokens[-1].line, 6)
+
def test_jinja(self):
tokenizer = Tokenizer()
@@ -26,6 +38,7 @@ class TestTokens(unittest.TestCase):
SELECT
{{ x }},
{{- x -}},
+ {# it's a comment #}
{% for x in y -%}
a {{+ b }}
{% endfor %};
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index c0d518d..0463aed 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -28,6 +28,11 @@ class TestTranspile(unittest.TestCase):
self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row")
+ self.assertEqual(
+ transpile("SELECT 1 FROM a.b.table1 t UNPIVOT((c3) FOR c4 IN (a, b))")[0],
+ "SELECT 1 FROM a.b.table1 AS t UNPIVOT((c3) FOR c4 IN (a, b))",
+ )
+
for key in ("union", "over", "from", "join"):
with self.subTest(f"alias {key}"):
self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}")