summaryrefslogtreecommitdiffstats
path: root/tests/test_transforms.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_transforms.py')
-rw-r--r--tests/test_transforms.py35
1 files changed, 34 insertions, 1 deletions
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 76d63b6..1e85b80 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -3,12 +3,13 @@ import unittest
from sqlglot import parse_one
from sqlglot.transforms import (
eliminate_distinct_on,
+ eliminate_qualify,
remove_precision_parameterized_types,
unalias_group,
)
-class TestTime(unittest.TestCase):
+class TestTransforms(unittest.TestCase):
maxDiff = None
def validate(self, transform, sql, target):
@@ -74,6 +75,38 @@ class TestTime(unittest.TestCase):
'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1',
)
+ def test_eliminate_qualify(self):
+ self.validate(
+ eliminate_qualify,
+ "SELECT i, a + 1 FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p) = 1",
+ "SELECT i, _c FROM (SELECT i, a + 1 AS _c, ROW_NUMBER() OVER (PARTITION BY p) AS _w, p FROM qt) AS _t WHERE _w = 1",
+ )
+ self.validate(
+ eliminate_qualify,
+ "SELECT i FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1 AND p = 0",
+ "SELECT i FROM (SELECT i, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w, p, o FROM qt) AS _t WHERE _w = 1 AND p = 0",
+ )
+ self.validate(
+ eliminate_qualify,
+ "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1",
+ "SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w FROM qt) AS _t WHERE _w = 1",
+ )
+ self.validate(
+ eliminate_qualify,
+ "SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS row_num FROM qt QUALIFY row_num = 1",
+ "SELECT i, p, o, row_num FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS row_num FROM qt) AS _t WHERE row_num = 1",
+ )
+ self.validate(
+ eliminate_qualify,
+ "SELECT * FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1",
+ "SELECT * FROM (SELECT *, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w, p, o FROM qt) AS _t WHERE _w = 1",
+ )
+ self.validate(
+ eliminate_qualify,
+ "SELECT c2, SUM(c3) OVER (PARTITION BY c2) AS r FROM t1 WHERE c3 < 4 GROUP BY c2, c3 HAVING SUM(c1) > 3 QUALIFY r IN (SELECT MIN(c1) FROM test GROUP BY c2 HAVING MIN(c1) > 3)",
+ "SELECT c2, r FROM (SELECT c2, SUM(c3) OVER (PARTITION BY c2) AS r, c1 FROM t1 WHERE c3 < 4 GROUP BY c2, c3 HAVING SUM(c1) > 3) AS _t WHERE r IN (SELECT MIN(c1) FROM test GROUP BY c2 HAVING MIN(c1) > 3)",
+ )
+
def test_remove_precision_parameterized_types(self):
self.validate(
remove_precision_parameterized_types,