diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-11 08:54:30 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-11 08:54:30 +0000 |
commit | 9ebe8c99ba4be74ccebf1b013f4e56ec09e023c1 (patch) | |
tree | 7ab2f39fbb6fd832aeea5cef45b54bfd59ba5ba5 /tests | |
parent | Adding upstream version 9.0.6. (diff) | |
download | sqlglot-9ebe8c99ba4be74ccebf1b013f4e56ec09e023c1.tar.xz sqlglot-9ebe8c99ba4be74ccebf1b013f4e56ec09e023c1.zip |
Adding upstream version 10.0.1.upstream/10.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
33 files changed, 1443 insertions, 518 deletions
diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py index 4a89c78..16f8922 100644 --- a/tests/dataframe/integration/dataframe_validator.py +++ b/tests/dataframe/integration/dataframe_validator.py @@ -1,9 +1,9 @@ -import sys import typing as t import unittest import warnings import sqlglot +from sqlglot.helper import PYTHON_VERSION from tests.helpers import SKIP_INTEGRATION if t.TYPE_CHECKING: @@ -11,7 +11,8 @@ if t.TYPE_CHECKING: @unittest.skipIf( - SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set" + SKIP_INTEGRATION or PYTHON_VERSION > (3, 10), + "Skipping Integration Tests since `SKIP_INTEGRATION` is set", ) class DataFrameValidator(unittest.TestCase): spark = None @@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase): # This is for test `test_branching_root_dataframes` config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")]) - cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate() + cls.spark = ( + SparkSession.builder.master("local[*]") + .appName("Unit-tests") + .config(conf=config) + .getOrCreate() + ) cls.spark.sparkContext.setLogLevel("ERROR") cls.sqlglot = SqlglotSparkSession() cls.spark_employee_schema = types.StructType( @@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase): ) cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType( [ - sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField( + "employee_id", sqlglotSparkTypes.IntegerType(), False + ), sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False), sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False), sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False), @@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase): (4, "Claire", "Littleton", 27, 2), (5, "Hugo", "Reyes", 29, 100), ] - cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema) - cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema) + cls.df_employee = cls.spark.createDataFrame( + data=employee_data, schema=cls.spark_employee_schema + ) + cls.dfs_employee = cls.sqlglot.createDataFrame( + data=employee_data, schema=cls.sqlglot_employee_schema + ) cls.df_employee.createOrReplaceTempView("employee") cls.spark_store_schema = types.StructType( @@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase): [ sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False), sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False), - sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField( + "district_id", sqlglotSparkTypes.IntegerType(), False + ), sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False), ] ) @@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase): (2, "Arrow", 2, 2000), ] cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema) - cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema) + cls.dfs_store = cls.sqlglot.createDataFrame( + data=store_data, schema=cls.sqlglot_store_schema + ) cls.df_store.createOrReplaceTempView("store") cls.spark_district_schema = types.StructType( @@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase): ) cls.sqlglot_district_schema = sqlglotSparkTypes.StructType( [ - sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False), - sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False), - sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False), + sqlglotSparkTypes.StructField( + "district_id", sqlglotSparkTypes.IntegerType(), False + ), + sqlglotSparkTypes.StructField( + "district_name", sqlglotSparkTypes.StringType(), False + ), + sqlglotSparkTypes.StructField( + "manager_name", sqlglotSparkTypes.StringType(), False + ), ] ) district_data = [ (1, "Temple", "Dogen"), (2, "Lighthouse", "Jacob"), ] - cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema) - cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema) + cls.df_district = cls.spark.createDataFrame( + data=district_data, schema=cls.spark_district_schema + ) + cls.dfs_district = cls.sqlglot.createDataFrame( + data=district_data, schema=cls.sqlglot_district_schema + ) cls.df_district.createOrReplaceTempView("district") sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema) sqlglot.schema.add_table("store", cls.sqlglot_store_schema) diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py index c740bec..19e3b89 100644 --- a/tests/dataframe/integration/test_dataframe.py +++ b/tests/dataframe/integration/test_dataframe.py @@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator): def test_alias_with_select(self): df_employee = self.df_spark_employee.alias("df_employee").select( - self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname + self.df_spark_employee["employee_id"], + F.col("df_employee.fname"), + self.df_spark_employee.lname, ) dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select( - self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname + self.df_sqlglot_employee["employee_id"], + SF.col("dfs_employee.fname"), + self.df_sqlglot_employee.lname, ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_case_when_otherwise(self): df = self.df_spark_employee.select( - F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")) + F.when( + (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), + F.lit("between 40 and 60"), + ) .when(F.col("age") < F.lit(40), "less than 40") .otherwise("greater than 60") ) dfs = self.df_sqlglot_employee.select( - SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")) + SF.when( + (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), + SF.lit("between 40 and 60"), + ) .when(SF.col("age") < SF.lit(40), "less than 40") .otherwise("greater than 60") ) @@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator): def test_case_when_no_otherwise(self): df = self.df_spark_employee.select( - F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when( - F.col("age") < F.lit(40), "less than 40" - ) + F.when( + (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), + F.lit("between 40 and 60"), + ).when(F.col("age") < F.lit(40), "less than 40") ) dfs = self.df_sqlglot_employee.select( - SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when( - SF.col("age") < SF.lit(40), "less than 40" - ) + SF.when( + (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), + SF.lit("between 40 and 60"), + ).when(SF.col("age") < SF.lit(40), "less than 40") ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_where_clause_multiple_and(self): - df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))) + df_employee = self.df_spark_employee.where( + (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")) + ) dfs_employee = self.df_sqlglot_employee.where( (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack")) ) @@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_where_clause_multiple_or(self): - df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))) + df_employee = self.df_spark_employee.where( + (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")) + ) dfs_employee = self.df_sqlglot_employee.where( (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate")) ) @@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator): dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37)) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] % F.lit(5) == F.lit(0) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] + F.lit(5) > F.lit(28) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] - F.lit(5) > F.lit(28) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) df_employee = self.df_spark_employee.where( self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2) ) dfs_employee = self.df_sqlglot_employee.where( - self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2) + self.df_sqlglot_employee["age"] * SF.lit(0.5) + == self.df_sqlglot_employee["age"] / SF.lit(2) ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_join_inner(self): - df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select( + df_joined = self.df_spark_employee.join( + self.df_spark_store, on=["store_id"], how="inner" + ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], F.col("lname"), @@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store.store_name, self.df_spark_store["num_sales"], ) - dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select( + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, on=["store_id"], how="inner" + ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], SF.col("lname"), @@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_inner_no_select(self): - df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join( - self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner" + df_joined = self.df_spark_employee.select( + F.col("store_id"), F.col("fname"), F.col("lname") + ).join( + self.df_spark_store.select(F.col("store_id"), F.col("store_name")), + on=["store_id"], + how="inner", ) - dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner" + dfs_joined = self.df_sqlglot_employee.select( + SF.col("store_id"), SF.col("fname"), SF.col("lname") + ).join( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), + on=["store_id"], + how="inner", ) self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_inner_equality_single(self): df_joined = self.df_spark_employee.join( - self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner" + self.df_spark_store, + on=self.df_spark_employee.store_id == self.df_spark_store.store_id, + how="inner", ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], @@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store["num_sales"], ) dfs_joined = self.df_sqlglot_employee.join( - self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner" + self.df_sqlglot_store, + on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, + how="inner", ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], @@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_full_outer(self): - df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select( + df_joined = self.df_spark_employee.join( + self.df_spark_store, on=["store_id"], how="full_outer" + ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], F.col("lname"), @@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store.store_name, self.df_spark_store["num_sales"], ) - dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select( + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, on=["store_id"], how="full_outer" + ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], SF.col("lname"), @@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator): def test_triple_join(self): df = ( - self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id) + self.df_employee.join( + self.df_store, on=self.df_employee.employee_id == self.df_store.store_id + ) .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id) .select( self.df_employee.employee_id, @@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator): ) ) dfs = ( - self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id) + self.dfs_employee.join( + self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id + ) .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id) .select( self.dfs_employee.employee_id, @@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_join_select_and_select_start(self): - df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join( - self.df_spark_store, "store_id", "inner" - ) + df = self.df_spark_employee.select( + F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id") + ).join(self.df_spark_store, "store_id", "inner") - dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join( - self.df_sqlglot_store, "store_id", "inner" - ) + dfs = self.df_sqlglot_employee.select( + SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id") + ).join(self.df_sqlglot_store, "store_id", "inner") self.compare_spark_with_sqlglot(df, dfs) @@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator): dfs_unioned = ( self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname")) .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))) - .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))) + .unionAll( + self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")) + ) ) self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) def test_union_by_name(self): - df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName( + df = self.df_spark_employee.select( + F.col("employee_id"), F.col("fname"), F.col("lname") + ).unionByName( self.df_spark_store.select( F.col("store_name").alias("lname"), F.col("store_id").alias("employee_id"), @@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator): ) ) - dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName( + dfs = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("fname"), SF.col("lname") + ).unionByName( self.df_sqlglot_store.select( SF.col("store_name").alias("lname"), SF.col("store_id").alias("employee_id"), @@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_order_by_default(self): - df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id")) + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales")) + .orderBy(F.col("district_id")) + ) dfs = ( - self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id")) + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales")) + .orderBy(SF.col("district_id")) ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator): df = ( self.df_spark_store.groupBy(F.col("district_id")) .agg(F.min("num_sales").alias("total_sales")) - .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()) + .orderBy( + F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last() + ) ) dfs = ( self.df_sqlglot_store.groupBy(SF.col("district_id")) .agg(SF.min("num_sales").alias("total_sales")) - .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()) + .orderBy( + SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last() + ) ) self.compare_spark_with_sqlglot(df, dfs) @@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator): df = ( self.df_spark_store.groupBy(F.col("district_id")) .agg(F.min("num_sales").alias("total_sales")) - .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()) + .orderBy( + F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first() + ) ) dfs = ( self.df_sqlglot_store.groupBy(SF.col("district_id")) .agg(SF.min("num_sales").alias("total_sales")) - .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first()) + .orderBy( + SF.when( + SF.col("district_id") == SF.lit(1), SF.col("district_id") + ).desc_nulls_first() + ) ) self.compare_spark_with_sqlglot(df, dfs) def test_intersect(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.intersect(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate) self.compare_spark_with_sqlglot(df, dfs) def test_intersect_all(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.intersectAll(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate) self.compare_spark_with_sqlglot(df, dfs) def test_except_all(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.exceptAll(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate) @@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_drop_na_default(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna() + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).dropna() dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator): ).dropna(how="any", thresh=2) dfs = self.df_sqlglot_employee.select( - SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + SF.lit(None), + SF.lit(1), + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), ).dropna(how="any", thresh=2) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator): ).dropna(thresh=1, subset="the_age") dfs = self.df_sqlglot_employee.select( - SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + SF.lit(None), + SF.lit(1), + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), ).dropna(thresh=1, subset="the_age") self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_dropna_na_function(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop() + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).na.drop() dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_fillna_default(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100) + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).fillna(100) dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_fillna_na_func(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100) + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).na.fill(100) dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_replace_basic(self): - df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100) + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( + to_replace=37, value=100 + ) dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( to_replace=37, value=100 @@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_replace_mapping(self): - df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100}) + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( + {37: 100} + ) - dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100}) + dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( + {37: 100} + ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator): to_replace=37, value=100 ) - dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace( - to_replace=37, value=100 - ) + dfs = self.df_sqlglot_employee.select( + SF.col("age"), SF.lit(37).alias("test_col") + ).na.replace(to_replace=37, value=100) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator): "first_name", "first_name_again" ) - dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed( - "first_name", "first_name_again" - ) + dfs = self.df_sqlglot_employee.select( + SF.col("fname").alias("first_name") + ).withColumnRenamed("first_name", "first_name_again") self.compare_spark_with_sqlglot(df, dfs) def test_drop_column_single(self): df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age") - dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age") + dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop( + "age" + ) self.compare_spark_with_sqlglot(df, dfs) @@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator): df_sqlglot_employee_cols = self.df_sqlglot_employee.select( SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id") ) - df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")) + df_sqlglot_store_cols = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("store_name") + ) dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop( df_sqlglot_employee_cols.age, ) diff --git a/tests/dataframe/integration/test_session.py b/tests/dataframe/integration/test_session.py index ff1477b..ec50034 100644 --- a/tests/dataframe/integration/test_session.py +++ b/tests/dataframe/integration/test_session.py @@ -23,6 +23,14 @@ class TestSessionFunc(DataFrameValidator): ON e.store_id = s.store_id """ - df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id"))) - dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id"))) + df = ( + self.spark.sql(query) + .groupBy(F.col("store_id")) + .agg(F.countDistinct(F.col("employee_id"))) + ) + dfs = ( + self.sqlglot.sql(query) + .groupBy(SF.col("store_id")) + .agg(SF.countDistinct(SF.col("employee_id"))) + ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py index fc56553..32ff8f2 100644 --- a/tests/dataframe/unit/dataframe_sql_validator.py +++ b/tests/dataframe/unit/dataframe_sql_validator.py @@ -25,11 +25,17 @@ class DataFrameSQLValidator(unittest.TestCase): (4, "Claire", "Littleton", 27, 2), (5, "Hugo", "Reyes", 29, 100), ] - self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema) + self.df_employee = self.spark.createDataFrame( + data=employee_data, schema=self.employee_schema + ) - def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False): + def compare_sql( + self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False + ): actual_sqls = df.sql(pretty=pretty) - expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements + expected_statements = ( + [expected_statements] if isinstance(expected_statements, str) else expected_statements + ) self.assertEqual(len(expected_statements), len(actual_sqls)) for expected, actual in zip(expected_statements, actual_sqls): self.assertEqual(expected, actual) diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py index 977971e..da18502 100644 --- a/tests/dataframe/unit/test_column.py +++ b/tests/dataframe/unit/test_column.py @@ -26,12 +26,14 @@ class TestDataframeColumn(unittest.TestCase): def test_and(self): self.assertEqual( - "cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql() + "cola = colb AND colc = cold", + ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql(), ) def test_or(self): self.assertEqual( - "cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql() + "cola = colb OR colc = cold", + ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql(), ) def test_mod(self): @@ -112,7 +114,9 @@ class TestDataframeColumn(unittest.TestCase): def test_when_otherwise(self): self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql()) - self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()) + self.assertEqual( + "CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql() + ) self.assertEqual( "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END", (F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(), @@ -148,7 +152,9 @@ class TestDataframeColumn(unittest.TestCase): self.assertEqual( "cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) " "AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)", - F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(), + F.col("cola") + .between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)) + .sql(), ) def test_over(self): diff --git a/tests/dataframe/unit/test_dataframe.py b/tests/dataframe/unit/test_dataframe.py index c222cac..e36667b 100644 --- a/tests/dataframe/unit/test_dataframe.py +++ b/tests/dataframe/unit/test_dataframe.py @@ -9,7 +9,9 @@ class TestDataframe(DataFrameSQLValidator): self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression)) def test_columns(self): - self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns) + self.assertEqual( + ["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns + ) def test_cache(self): df = self.df_employee.select("fname").cache() diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index eadbb93..8e5e5cd 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -925,12 +925,17 @@ class TestFunctions(unittest.TestCase): col = SF.window(SF.col("cola"), "10 minutes") self.assertEqual("WINDOW(cola, '10 minutes')", col.sql()) col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds") - self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()) + self.assertEqual( + "WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql() + ) col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds") - self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()) + self.assertEqual( + "WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql() + ) col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds") self.assertEqual( - "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", col_no_slide.sql() + "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", + col_no_slide.sql(), ) def test_session_window(self): @@ -1359,9 +1364,13 @@ class TestFunctions(unittest.TestCase): def test_from_json(self): col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + self.assertEqual( + "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql() + ) col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + self.assertEqual( + "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql() + ) col_no_option = SF.from_json("cola", "cola INT") self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql()) @@ -1375,7 +1384,9 @@ class TestFunctions(unittest.TestCase): def test_schema_of_json(self): col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + self.assertEqual( + "SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql() + ) col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy")) self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) col_no_option = SF.schema_of_json("cola") @@ -1429,7 +1440,10 @@ class TestFunctions(unittest.TestCase): col = SF.array_sort(SF.col("cola")) self.assertEqual("ARRAY_SORT(cola)", col.sql()) col_comparator = SF.array_sort( - "cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x)) + "cola", + lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise( + SF.length(y) - SF.length(x) + ), ) self.assertEqual( "ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)", @@ -1504,9 +1518,13 @@ class TestFunctions(unittest.TestCase): def test_from_csv(self): col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + self.assertEqual( + "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql() + ) col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + self.assertEqual( + "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql() + ) col_no_option = SF.from_csv("cola", "cola INT") self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql()) @@ -1535,7 +1553,9 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql()) col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count) - self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()) + self.assertEqual( + "TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql() + ) def test_exists(self): col_str = SF.exists("cola", lambda x: x % 2 == 0) @@ -1558,10 +1578,13 @@ class TestFunctions(unittest.TestCase): self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql()) col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i)) self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql()) - col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)) + col_custom_names = SF.filter( + "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count) + ) self.assertEqual( - "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql() + "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", + col_custom_names.sql(), ) def test_zip_with(self): @@ -1570,7 +1593,9 @@ class TestFunctions(unittest.TestCase): col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y)) self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql()) col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r)) - self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()) + self.assertEqual( + "ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql() + ) def test_transform_keys(self): col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k)) @@ -1586,7 +1611,9 @@ class TestFunctions(unittest.TestCase): col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v)) self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql()) col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value)) - self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()) + self.assertEqual( + "TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql() + ) def test_map_filter(self): col_str = SF.map_filter("cola", lambda k, v: k > v) diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py index 158dcec..7e8bfad 100644 --- a/tests/dataframe/unit/test_session.py +++ b/tests/dataframe/unit/test_session.py @@ -21,9 +21,7 @@ class TestDataframeSession(DataFrameSQLValidator): def test_cdf_no_schema(self): df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]]) - expected = ( - "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)" - ) + expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)" self.compare_sql(df, expected) def test_cdf_row_mixed_primitives(self): @@ -77,7 +75,8 @@ class TestDataframeSession(DataFrameSQLValidator): sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) df = self.spark.sql(query) self.assertIn( - "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False) + "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", + df.sql(pretty=False), ) @mock.patch("sqlglot.schema", MappingSchema()) @@ -104,9 +103,7 @@ class TestDataframeSession(DataFrameSQLValidator): query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1" sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) df = self.spark.sql(query) - expected = ( - "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" - ) + expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" self.compare_sql(df, expected) def test_session_create_builder_patterns(self): diff --git a/tests/dataframe/unit/test_types.py b/tests/dataframe/unit/test_types.py index 1f6c5dc..52f5d72 100644 --- a/tests/dataframe/unit/test_types.py +++ b/tests/dataframe/unit/test_types.py @@ -53,7 +53,10 @@ class TestDataframeTypes(unittest.TestCase): self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString()) def test_map(self): - self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString()) + self.assertEqual( + "map<int, string>", + types.MapType(types.IntegerType(), types.StringType()).simpleString(), + ) def test_struct_field(self): self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString()) diff --git a/tests/dataframe/unit/test_window.py b/tests/dataframe/unit/test_window.py index eea4582..70a868a 100644 --- a/tests/dataframe/unit/test_window.py +++ b/tests/dataframe/unit/test_window.py @@ -39,22 +39,38 @@ class TestDataframeWindow(unittest.TestCase): def test_window_rows_unbounded(self): rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2) - self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql()) + self.assertEqual( + "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", + rows_between_unbounded_start.sql(), + ) rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing) - self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql()) - rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) self.assertEqual( - "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql() + "OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + rows_between_unbounded_end.sql(), + ) + rows_between_unbounded_both = Window.rowsBetween( + Window.unboundedPreceding, Window.unboundedFollowing + ) + self.assertEqual( + "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", + rows_between_unbounded_both.sql(), ) def test_window_range_unbounded(self): range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2) self.assertEqual( - "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql() + "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", + range_between_unbounded_start.sql(), ) range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing) - self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql()) - range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing) self.assertEqual( - "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql() + "OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + range_between_unbounded_end.sql(), + ) + range_between_unbounded_both = Window.rangeBetween( + Window.unboundedPreceding, Window.unboundedFollowing + ) + self.assertEqual( + "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", + range_between_unbounded_both.sql(), ) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 050d41e..a0ebc45 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -157,6 +157,14 @@ class TestBigQuery(Validator): }, ) + self.validate_all( + "DIV(x, y)", + write={ + "bigquery": "DIV(x, y)", + "duckdb": "CAST(x / y AS INT)", + }, + ) + self.validate_identity( "SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)" ) @@ -284,4 +292,6 @@ class TestBigQuery(Validator): "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'" ) self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)") - self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t") + self.validate_identity( + "CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t" + ) diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 715bf10..efb41bb 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -18,7 +18,6 @@ class TestClickhouse(Validator): "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", }, ) - self.validate_all( "CAST(1 AS NULLABLE(Int64))", write={ @@ -31,3 +30,7 @@ class TestClickhouse(Validator): "clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))", }, ) + self.validate_all( + "SELECT x #! comment", + write={"": "SELECT x /* comment */"}, + ) diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index e242e73..2168f55 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -22,7 +22,8 @@ class TestDatabricks(Validator): }, ) self.validate_all( - "SELECT DATEDIFF('end', 'start')", write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"} + "SELECT DATEDIFF('end', 'start')", + write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"}, ) self.validate_all( "SELECT DATE_ADD('2020-01-01', 1)", diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 3b837df..1913f53 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1,20 +1,18 @@ import unittest -from sqlglot import ( - Dialect, - Dialects, - ErrorLevel, - UnsupportedError, - parse_one, - transpile, -) +from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one class Validator(unittest.TestCase): dialect = None - def validate_identity(self, sql): - self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql) + def parse_one(self, sql): + return parse_one(sql, read=self.dialect) + + def validate_identity(self, sql, write_sql=None): + expression = self.parse_one(sql) + self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect)) + return expression def validate_all(self, sql, read=None, write=None, pretty=False): """ @@ -28,12 +26,14 @@ class Validator(unittest.TestCase): read (dict): Mapping of dialect -> SQL write (dict): Mapping of dialect -> SQL """ - expression = parse_one(sql, read=self.dialect) + expression = self.parse_one(sql) for read_dialect, read_sql in (read or {}).items(): with self.subTest(f"{read_dialect} -> {sql}"): self.assertEqual( - parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE), + parse_one(read_sql, read_dialect).sql( + self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty + ), sql, ) @@ -83,10 +83,6 @@ class TestDialect(Validator): ) self.validate_all( "CAST(a AS BINARY(4))", - read={ - "presto": "CAST(a AS VARBINARY(4))", - "sqlite": "CAST(a AS VARBINARY(4))", - }, write={ "bigquery": "CAST(a AS BINARY(4))", "clickhouse": "CAST(a AS BINARY(4))", @@ -104,6 +100,24 @@ class TestDialect(Validator): }, ) self.validate_all( + "CAST(a AS VARBINARY(4))", + write={ + "bigquery": "CAST(a AS VARBINARY(4))", + "clickhouse": "CAST(a AS VARBINARY(4))", + "duckdb": "CAST(a AS VARBINARY(4))", + "mysql": "CAST(a AS VARBINARY(4))", + "hive": "CAST(a AS BINARY(4))", + "oracle": "CAST(a AS BLOB(4))", + "postgres": "CAST(a AS BYTEA(4))", + "presto": "CAST(a AS VARBINARY(4))", + "redshift": "CAST(a AS VARBYTE(4))", + "snowflake": "CAST(a AS VARBINARY(4))", + "sqlite": "CAST(a AS BLOB(4))", + "spark": "CAST(a AS BINARY(4))", + "starrocks": "CAST(a AS VARBINARY(4))", + }, + ) + self.validate_all( "CAST(MAP('a', '1') AS MAP(TEXT, TEXT))", write={ "clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))", @@ -472,45 +486,57 @@ class TestDialect(Validator): }, ) self.validate_all( - "DATE_TRUNC(x, 'day')", + "DATE_TRUNC('day', x)", write={ "mysql": "DATE(x)", - "starrocks": "DATE(x)", }, ) self.validate_all( - "DATE_TRUNC(x, 'week')", + "DATE_TRUNC('week', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')", }, ) self.validate_all( - "DATE_TRUNC(x, 'month')", + "DATE_TRUNC('month', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')", }, ) self.validate_all( - "DATE_TRUNC(x, 'quarter')", + "DATE_TRUNC('quarter', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')", }, ) self.validate_all( - "DATE_TRUNC(x, 'year')", + "DATE_TRUNC('year', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", }, ) self.validate_all( - "DATE_TRUNC(x, 'millenium')", + "DATE_TRUNC('millenium', x)", write={ "mysql": UnsupportedError, - "starrocks": UnsupportedError, + }, + ) + self.validate_all( + "DATE_TRUNC('year', x)", + read={ + "starrocks": "DATE_TRUNC('year', x)", + }, + write={ + "starrocks": "DATE_TRUNC('year', x)", + }, + ) + self.validate_all( + "DATE_TRUNC(x, year)", + read={ + "bigquery": "DATE_TRUNC(x, year)", + }, + write={ + "bigquery": "DATE_TRUNC(x, year)", }, ) self.validate_all( @@ -564,6 +590,22 @@ class TestDialect(Validator): "spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", }, ) + self.validate_all( + "TIMESTAMP '2022-01-01'", + write={ + "mysql": "CAST('2022-01-01' AS TIMESTAMP)", + "starrocks": "CAST('2022-01-01' AS DATETIME)", + "hive": "CAST('2022-01-01' AS TIMESTAMP)", + }, + ) + self.validate_all( + "TIMESTAMP('2022-01-01')", + write={ + "mysql": "TIMESTAMP('2022-01-01')", + "starrocks": "TIMESTAMP('2022-01-01')", + "hive": "TIMESTAMP('2022-01-01')", + }, + ) for unit in ("DAY", "MONTH", "YEAR"): self.validate_all( @@ -1002,7 +1044,10 @@ class TestDialect(Validator): ) def test_limit(self): - self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"}) + self.validate_all( + "SELECT * FROM data LIMIT 10, 20", + write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"}, + ) self.validate_all( "SELECT x FROM y LIMIT 10", write={ @@ -1132,3 +1177,56 @@ class TestDialect(Validator): "sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c", }, ) + + def test_nullsafe_eq(self): + self.validate_all( + "SELECT a IS NOT DISTINCT FROM b", + read={ + "mysql": "SELECT a <=> b", + "postgres": "SELECT a IS NOT DISTINCT FROM b", + }, + write={ + "mysql": "SELECT a <=> b", + "postgres": "SELECT a IS NOT DISTINCT FROM b", + }, + ) + + def test_nullsafe_neq(self): + self.validate_all( + "SELECT a IS DISTINCT FROM b", + read={ + "postgres": "SELECT a IS DISTINCT FROM b", + }, + write={ + "mysql": "SELECT NOT a <=> b", + "postgres": "SELECT a IS DISTINCT FROM b", + }, + ) + + def test_hash_comments(self): + self.validate_all( + "SELECT 1 /* arbitrary content,,, until end-of-line */", + read={ + "mysql": "SELECT 1 # arbitrary content,,, until end-of-line", + "bigquery": "SELECT 1 # arbitrary content,,, until end-of-line", + "clickhouse": "SELECT 1 #! arbitrary content,,, until end-of-line", + }, + ) + self.validate_all( + """/* comment1 */ +SELECT + x, -- comment2 + y -- comment3""", + read={ + "mysql": """SELECT # comment1 + x, # comment2 + y # comment3""", + "bigquery": """SELECT # comment1 + x, # comment2 + y # comment3""", + "clickhouse": """SELECT # comment1 + x, # comment2 + y # comment3""", + }, + pretty=True, + ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index a25871c..1ba118b 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -1,3 +1,4 @@ +from sqlglot import expressions as exp from tests.dialects.test_dialect import Validator @@ -20,6 +21,52 @@ class TestMySQL(Validator): self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')") self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')") self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')") + self.validate_identity("@@GLOBAL.max_connections") + + # SET Commands + self.validate_identity("SET @var_name = expr") + self.validate_identity("SET @name = 43") + self.validate_identity("SET @total_tax = (SELECT SUM(tax) FROM taxable_transactions)") + self.validate_identity("SET GLOBAL max_connections = 1000") + self.validate_identity("SET @@GLOBAL.max_connections = 1000") + self.validate_identity("SET SESSION sql_mode = 'TRADITIONAL'") + self.validate_identity("SET LOCAL sql_mode = 'TRADITIONAL'") + self.validate_identity("SET @@SESSION.sql_mode = 'TRADITIONAL'") + self.validate_identity("SET @@LOCAL.sql_mode = 'TRADITIONAL'") + self.validate_identity("SET @@sql_mode = 'TRADITIONAL'") + self.validate_identity("SET sql_mode = 'TRADITIONAL'") + self.validate_identity("SET PERSIST max_connections = 1000") + self.validate_identity("SET @@PERSIST.max_connections = 1000") + self.validate_identity("SET PERSIST_ONLY back_log = 100") + self.validate_identity("SET @@PERSIST_ONLY.back_log = 100") + self.validate_identity("SET @@SESSION.max_join_size = DEFAULT") + self.validate_identity("SET @@SESSION.max_join_size = @@GLOBAL.max_join_size") + self.validate_identity("SET @x = 1, SESSION sql_mode = ''") + self.validate_identity( + "SET GLOBAL sort_buffer_size = 1000000, SESSION sort_buffer_size = 1000000" + ) + self.validate_identity( + "SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000" + ) + self.validate_identity("SET GLOBAL max_connections = 1000, sort_buffer_size = 1000000") + self.validate_identity("SET @@GLOBAL.sort_buffer_size = 50000, sort_buffer_size = 1000000") + self.validate_identity("SET CHARACTER SET 'utf8'") + self.validate_identity("SET CHARACTER SET utf8") + self.validate_identity("SET CHARACTER SET DEFAULT") + self.validate_identity("SET NAMES 'utf8'") + self.validate_identity("SET NAMES DEFAULT") + self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'") + self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci") + self.validate_identity("SET autocommit = ON") + + def test_escape(self): + self.validate_all( + r"'a \' b '' '", + write={ + "mysql": r"'a '' b '' '", + "spark": r"'a \' b \' '", + }, + ) def test_introducers(self): self.validate_all( @@ -115,14 +162,6 @@ class TestMySQL(Validator): }, ) - def test_hash_comments(self): - self.validate_all( - "SELECT 1 # arbitrary content,,, until end-of-line", - write={ - "mysql": "SELECT 1", - }, - ) - def test_mysql(self): self.validate_all( "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", @@ -174,3 +213,242 @@ COMMENT='客户账户表'""" }, pretty=True, ) + + def test_show_simple(self): + for key, write_key in [ + ("BINARY LOGS", "BINARY LOGS"), + ("MASTER LOGS", "BINARY LOGS"), + ("STORAGE ENGINES", "ENGINES"), + ("ENGINES", "ENGINES"), + ("EVENTS", "EVENTS"), + ("MASTER STATUS", "MASTER STATUS"), + ("PLUGINS", "PLUGINS"), + ("PRIVILEGES", "PRIVILEGES"), + ("PROFILES", "PROFILES"), + ("REPLICAS", "REPLICAS"), + ("SLAVE HOSTS", "REPLICAS"), + ]: + show = self.validate_identity(f"SHOW {key}", f"SHOW {write_key}") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, write_key) + + def test_show_events(self): + for key in ["BINLOG", "RELAYLOG"]: + show = self.validate_identity(f"SHOW {key} EVENTS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, f"{key} EVENTS") + + show = self.validate_identity(f"SHOW {key} EVENTS IN 'log' FROM 1 LIMIT 2, 3") + self.assertEqual(show.text("log"), "log") + self.assertEqual(show.text("position"), "1") + self.assertEqual(show.text("limit"), "3") + self.assertEqual(show.text("offset"), "2") + + show = self.validate_identity(f"SHOW {key} EVENTS LIMIT 1") + self.assertEqual(show.text("limit"), "1") + self.assertIsNone(show.args.get("offset")) + + def test_show_like_or_where(self): + for key, write_key in [ + ("CHARSET", "CHARACTER SET"), + ("CHARACTER SET", "CHARACTER SET"), + ("COLLATION", "COLLATION"), + ("DATABASES", "DATABASES"), + ("FUNCTION STATUS", "FUNCTION STATUS"), + ("PROCEDURE STATUS", "PROCEDURE STATUS"), + ("GLOBAL STATUS", "GLOBAL STATUS"), + ("SESSION STATUS", "STATUS"), + ("STATUS", "STATUS"), + ("GLOBAL VARIABLES", "GLOBAL VARIABLES"), + ("SESSION VARIABLES", "VARIABLES"), + ("VARIABLES", "VARIABLES"), + ]: + expected_name = write_key.strip("GLOBAL").strip() + template = "SHOW {}" + show = self.validate_identity(template.format(key), template.format(write_key)) + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, expected_name) + + template = "SHOW {} LIKE '%foo%'" + show = self.validate_identity(template.format(key), template.format(write_key)) + self.assertIsInstance(show, exp.Show) + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + template = "SHOW {} WHERE Column_name LIKE '%foo%'" + show = self.validate_identity(template.format(key), template.format(write_key)) + self.assertIsInstance(show, exp.Show) + self.assertIsInstance(show.args["where"], exp.Where) + self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'") + + def test_show_columns(self): + show = self.validate_identity("SHOW COLUMNS FROM tbl_name") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "COLUMNS") + self.assertEqual(show.text("target"), "tbl_name") + self.assertFalse(show.args["full"]) + + show = self.validate_identity("SHOW FULL COLUMNS FROM tbl_name FROM db_name LIKE '%foo%'") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.text("target"), "tbl_name") + self.assertTrue(show.args["full"]) + self.assertEqual(show.text("db"), "db_name") + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + def test_show_name(self): + for key in [ + "CREATE DATABASE", + "CREATE EVENT", + "CREATE FUNCTION", + "CREATE PROCEDURE", + "CREATE TABLE", + "CREATE TRIGGER", + "CREATE VIEW", + "FUNCTION CODE", + "PROCEDURE CODE", + ]: + show = self.validate_identity(f"SHOW {key} foo") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, key) + self.assertEqual(show.text("target"), "foo") + + def test_show_grants(self): + show = self.validate_identity(f"SHOW GRANTS FOR foo") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "GRANTS") + self.assertEqual(show.text("target"), "foo") + + def test_show_engine(self): + show = self.validate_identity("SHOW ENGINE foo STATUS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "ENGINE") + self.assertEqual(show.text("target"), "foo") + self.assertFalse(show.args["mutex"]) + + show = self.validate_identity("SHOW ENGINE foo MUTEX") + self.assertEqual(show.name, "ENGINE") + self.assertEqual(show.text("target"), "foo") + self.assertTrue(show.args["mutex"]) + + def test_show_errors(self): + for key in ["ERRORS", "WARNINGS"]: + show = self.validate_identity(f"SHOW {key}") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, key) + + show = self.validate_identity(f"SHOW {key} LIMIT 2, 3") + self.assertEqual(show.text("limit"), "3") + self.assertEqual(show.text("offset"), "2") + + def test_show_index(self): + show = self.validate_identity("SHOW INDEX FROM foo") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "INDEX") + self.assertEqual(show.text("target"), "foo") + + show = self.validate_identity("SHOW INDEX FROM foo FROM bar") + self.assertEqual(show.text("db"), "bar") + + def test_show_db_like_or_where_sql(self): + for key in [ + "OPEN TABLES", + "TABLE STATUS", + "TRIGGERS", + ]: + show = self.validate_identity(f"SHOW {key}") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, key) + + show = self.validate_identity(f"SHOW {key} FROM db_name") + self.assertEqual(show.name, key) + self.assertEqual(show.text("db"), "db_name") + + show = self.validate_identity(f"SHOW {key} LIKE '%foo%'") + self.assertEqual(show.name, key) + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + show = self.validate_identity(f"SHOW {key} WHERE Column_name LIKE '%foo%'") + self.assertEqual(show.name, key) + self.assertIsInstance(show.args["where"], exp.Where) + self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'") + + def test_show_processlist(self): + show = self.validate_identity("SHOW PROCESSLIST") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "PROCESSLIST") + self.assertFalse(show.args["full"]) + + show = self.validate_identity("SHOW FULL PROCESSLIST") + self.assertEqual(show.name, "PROCESSLIST") + self.assertTrue(show.args["full"]) + + def test_show_profile(self): + show = self.validate_identity("SHOW PROFILE") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "PROFILE") + + show = self.validate_identity("SHOW PROFILE BLOCK IO") + self.assertEqual(show.args["types"][0].name, "BLOCK IO") + + show = self.validate_identity( + "SHOW PROFILE BLOCK IO, PAGE FAULTS FOR QUERY 1 OFFSET 2 LIMIT 3" + ) + self.assertEqual(show.args["types"][0].name, "BLOCK IO") + self.assertEqual(show.args["types"][1].name, "PAGE FAULTS") + self.assertEqual(show.text("query"), "1") + self.assertEqual(show.text("offset"), "2") + self.assertEqual(show.text("limit"), "3") + + def test_show_replica_status(self): + show = self.validate_identity("SHOW REPLICA STATUS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "REPLICA STATUS") + + show = self.validate_identity("SHOW SLAVE STATUS", "SHOW REPLICA STATUS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "REPLICA STATUS") + + show = self.validate_identity("SHOW REPLICA STATUS FOR CHANNEL channel_name") + self.assertEqual(show.text("channel"), "channel_name") + + def test_show_tables(self): + show = self.validate_identity("SHOW TABLES") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "TABLES") + + show = self.validate_identity("SHOW FULL TABLES FROM db_name LIKE '%foo%'") + self.assertTrue(show.args["full"]) + self.assertEqual(show.text("db"), "db_name") + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + def test_set_variable(self): + cmd = self.parse_one("SET SESSION x = 1") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "SESSION") + self.assertIsInstance(item.this, exp.EQ) + self.assertEqual(item.this.left.name, "x") + self.assertEqual(item.this.right.name, "1") + + cmd = self.parse_one("SET @@GLOBAL.x = @@GLOBAL.y") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "") + self.assertIsInstance(item.this, exp.EQ) + self.assertIsInstance(item.this.left, exp.SessionParameter) + self.assertIsInstance(item.this.right, exp.SessionParameter) + + cmd = self.parse_one("SET NAMES 'charset_name' COLLATE 'collation_name'") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "NAMES") + self.assertEqual(item.name, "charset_name") + self.assertEqual(item.text("collate"), "collation_name") + + cmd = self.parse_one("SET CHARSET DEFAULT") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "CHARACTER SET") + self.assertEqual(item.this.name, "DEFAULT") + + cmd = self.parse_one("SET x = 1, y = 2") + self.assertEqual(len(cmd.expressions), 2) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 35141e2..8294eea 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -8,7 +8,9 @@ class TestPostgres(Validator): def test_ddl(self): self.validate_all( "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)", - write={"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"}, + write={ + "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)" + }, ) self.validate_all( "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)", @@ -59,15 +61,27 @@ class TestPostgres(Validator): def test_postgres(self): self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END") - self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END") - self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END") - self.validate_identity('SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')') + self.validate_identity( + "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END" + ) + self.validate_identity( + "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END" + ) + self.validate_identity( + 'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')' + ) self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')") - self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')") - self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))") + self.validate_identity( + "SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')" + ) + self.validate_identity( + "SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))" + ) self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')") self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)") - self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')") + self.validate_identity( + "SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')" + ) self.validate_identity("COMMENT ON TABLE mytable IS 'this'") self.validate_identity("SELECT e'\\xDEADBEEF'") self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") @@ -75,7 +89,7 @@ class TestPostgres(Validator): self.validate_all( "CREATE TABLE x (a UUID, b BYTEA)", write={ - "duckdb": "CREATE TABLE x (a UUID, b BINARY)", + "duckdb": "CREATE TABLE x (a UUID, b VARBINARY)", "presto": "CREATE TABLE x (a UUID, b VARBINARY)", "hive": "CREATE TABLE x (a UUID, b BINARY)", "spark": "CREATE TABLE x (a UUID, b BINARY)", @@ -153,7 +167,9 @@ class TestPostgres(Validator): ) self.validate_all( "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss", - read={"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"}, + read={ + "postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss" + }, ) self.validate_all( "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL", @@ -169,11 +185,15 @@ class TestPostgres(Validator): ) self.validate_all( "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL", - read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"}, + read={ + "postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL" + }, ) self.validate_all( "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL", - read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"}, + read={ + "postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL" + }, ) self.validate_all( "'[1,2,3]'::json->2", @@ -184,7 +204,8 @@ class TestPostgres(Validator): write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->'b'"""}, ) self.validate_all( - """'{"x": {"y": 1}}'::json->'x'->'y'""", write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""} + """'{"x": {"y": 1}}'::json->'x'->'y'""", + write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""}, ) self.validate_all( """'{"x": {"y": 1}}'::json->'x'::json->'y'""", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 1ed2bb6..5309a34 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -61,4 +61,6 @@ class TestRedshift(Validator): "SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'" ) self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)") - self.validate_identity("SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'") + self.validate_identity( + "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'" + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index fea2311..1846b17 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -336,7 +336,8 @@ class TestSnowflake(Validator): def test_table_literal(self): # All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html self.validate_all( - r"""SELECT * FROM TABLE('MYTABLE')""", write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""} + r"""SELECT * FROM TABLE('MYTABLE')""", + write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""}, ) self.validate_all( @@ -352,15 +353,123 @@ class TestSnowflake(Validator): write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""}, ) - self.validate_all(r"""SELECT * FROM TABLE($MYVAR)""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""}) + self.validate_all( + r"""SELECT * FROM TABLE($MYVAR)""", + write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""}, + ) - self.validate_all(r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""}) + self.validate_all( + r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""} + ) self.validate_all( - r"""SELECT * FROM TABLE(:BINDING)""", write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""} + r"""SELECT * FROM TABLE(:BINDING)""", + write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""}, ) self.validate_all( r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""}, ) + + def test_flatten(self): + self.validate_all( + """ + select + dag_report.acct_id, + dag_report.report_date, + dag_report.report_uuid, + dag_report.airflow_name, + dag_report.dag_id, + f.value::varchar as operator + from cs.telescope.dag_report, + table(flatten(input=>split(operators, ','))) f + """, + write={ + "snowflake": """SELECT + dag_report.acct_id, + dag_report.report_date, + dag_report.report_uuid, + dag_report.airflow_name, + dag_report.dag_id, + CAST(f.value AS VARCHAR) AS operator +FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS f""" + }, + pretty=True, + ) + + # All examples from https://docs.snowflake.com/en/sql-reference/functions/flatten.html#syntax + self.validate_all( + "SELECT * FROM TABLE(FLATTEN(input => parse_json('[1, ,77]'))) f", + write={ + "snowflake": "SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[1, ,77]'))) AS f" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), outer => true)) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), outer => TRUE)) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), path => 'b')) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), path => 'b')) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'))) f""", + write={"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'))) AS f"""}, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'), outer => true)) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'), outer => TRUE)) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true)) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE)) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true, mode => 'object')) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE, mode => 'object')) AS f""" + }, + ) + + self.validate_all( + """ + SELECT id as "ID", + f.value AS "Contact", + f1.value:type AS "Type", + f1.value:content AS "Details" + FROM persons p, + lateral flatten(input => p.c, path => 'contact') f, + lateral flatten(input => f.value:business) f1 + """, + write={ + "snowflake": """SELECT + id AS "ID", + f.value AS "Contact", + f1.value['type'] AS "Type", + f1.value['content'] AS "Details" +FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL FLATTEN(input => f.value['business']) f1""", + }, + pretty=True, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 8605bd1..4470722 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -284,4 +284,6 @@ TBLPROPERTIES ( ) def test_iif(self): - self.validate_all("SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"}) + self.validate_all( + "SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"} + ) diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py index 1fe1a57..35d8b45 100644 --- a/tests/dialects/test_starrocks.py +++ b/tests/dialects/test_starrocks.py @@ -6,3 +6,6 @@ class TestMySQL(Validator): def test_identity(self): self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") + + def test_time(self): + self.validate_identity("TIMESTAMP('2022-01-01')") diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index d22a9c2..a60f48d 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -278,12 +278,19 @@ class TestTSQL(Validator): def test_add_date(self): self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')") self.validate_all( - "SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"} + "SELECT DATEADD(year, 1, '2017/08/25')", + write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"}, + ) + self.validate_all( + "SELECT DATEADD(qq, 1, '2017/08/25')", + write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"}, ) - self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"}) self.validate_all( "SELECT DATEADD(wk, 1, '2017/08/25')", - write={"spark": "SELECT DATE_ADD('2017/08/25', 7)", "databricks": "SELECT DATEADD(week, 1, '2017/08/25')"}, + write={ + "spark": "SELECT DATE_ADD('2017/08/25', 7)", + "databricks": "SELECT DATEADD(week, 1, '2017/08/25')", + }, ) def test_date_diff(self): @@ -370,13 +377,21 @@ class TestTSQL(Validator): "SELECT FORMAT(1000000.01,'###,###.###')", write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"}, ) - self.validate_all("SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}) + self.validate_all( + "SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"} + ) self.validate_all( "SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT('01-01-1991', 'dd.mm.yyyy')"}, ) self.validate_all( - "SELECT FORMAT(date_col, 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"} + "SELECT FORMAT(date_col, 'dd.mm.yyyy')", + write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"}, + ) + self.validate_all( + "SELECT FORMAT(date_col, 'm')", + write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"}, + ) + self.validate_all( + "SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"} ) - self.validate_all("SELECT FORMAT(date_col, 'm')", write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"}) - self.validate_all("SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index d7084ac..836ab28 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -523,6 +523,8 @@ DROP VIEW a.b DROP VIEW IF EXISTS a DROP VIEW IF EXISTS a.b SHOW TABLES +USE db +ROLLBACK EXPLAIN SELECT * FROM x INSERT INTO x SELECT * FROM y INSERT INTO x (SELECT * FROM y) @@ -569,3 +571,13 @@ SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1) SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3) SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo) SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl) +SELECT CAST(x AS INT) /* comment */ FROM foo +SELECT a /* x */, b /* x */ +SELECT * FROM foo /* x */, bla /* x */ +SELECT 1 /* comment */ + 1 +SELECT 1 /* c1 */ + 2 /* c2 */ +SELECT 1 /* c1 */ + 2 /* c2 */ + 3 /* c3 */ +SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */ +SELECT x FROM a.b.c /* x */, e.f.g /* x */ +SELECT FOO(x /* c */) /* FOO */, b /* b */ +SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */ diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index a958c08..1176078 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -104,6 +104,16 @@ SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_ SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x; SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x; +# dialect: starrocks +# execute: false +SELECT DATE_TRUNC('week', a) AS a FROM x; +SELECT DATE_TRUNC('week', x.a) AS a FROM x AS x; + +# dialect: bigquery +# execute: false +SELECT DATE_TRUNC(a, MONTH) AS a FROM x; +SELECT DATE_TRUNC(x.a, MONTH) AS a FROM x AS x; + -------------------------------------- -- Derived tables -------------------------------------- diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 07e818f..7207ba2 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -79,6 +79,15 @@ NULL; NULL = NULL; NULL; +NULL <=> NULL; +TRUE; + +a IS NOT DISTINCT FROM a; +TRUE; + +NULL IS DISTINCT FROM NULL; +FALSE; + NOT (NOT TRUE); TRUE; diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql index 2570650..5e27b5e 100644 --- a/tests/fixtures/pretty.sql +++ b/tests/fixtures/pretty.sql @@ -287,3 +287,31 @@ SELECT "fffffff" ) ); +/* + multi + line + comment +*/ +SELECT * FROM foo; +/* + multi + line + comment +*/ +SELECT + * +FROM foo; +SELECT x FROM a.b.c /*x*/, e.f.g /*x*/; +SELECT + x +FROM a.b.c /* x */, e.f.g /* x */; +SELECT x FROM (SELECT * FROM bla /*x*/WHERE id = 1) /*x*/; +SELECT + x +FROM ( + SELECT + * + FROM bla /* x */ + WHERE + id = 1 +) /* x */; diff --git a/tests/test_build.py b/tests/test_build.py index b7b6865..721c868 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -100,15 +100,21 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl LEFT OUTER JOIN tbl2", ), ( - lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer"), + lambda: select("x") + .from_("tbl") + .join(exp.Table(this="tbl2"), join_type="left outer"), "SELECT x FROM tbl LEFT OUTER JOIN tbl2", ), ( - lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"), + lambda: select("x") + .from_("tbl") + .join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"), "SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo", ), ( - lambda: select("x").from_("tbl").join(select("y").from_("tbl2"), join_type="left outer"), + lambda: select("x") + .from_("tbl") + .join(select("y").from_("tbl2"), join_type="left outer"), "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)", ), ( @@ -131,7 +137,9 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased", ), ( - lambda: select("x").from_("tbl").join(parse_one("left join x", into=exp.Join), on="a=b"), + lambda: select("x") + .from_("tbl") + .join(parse_one("left join x", into=exp.Join), on="a=b"), "SELECT x FROM tbl LEFT JOIN x ON a = b", ), ( @@ -139,7 +147,9 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl LEFT JOIN x ON a = b", ), ( - lambda: select("x").from_("tbl").join("select b from tbl2", on="a=b", join_type="left"), + lambda: select("x") + .from_("tbl") + .join("select b from tbl2", on="a=b", join_type="left"), "SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b", ), ( @@ -162,7 +172,10 @@ class TestBuild(unittest.TestCase): ( lambda: select("x", "y", "z") .from_("merged_df") - .join("vte_diagnosis_df", using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")]), + .join( + "vte_diagnosis_df", + using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")], + ), "SELECT x, y, z FROM merged_df JOIN vte_diagnosis_df USING (patient_id, encounter_id)", ), ( @@ -222,7 +235,10 @@ class TestBuild(unittest.TestCase): "SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a", ), ( - lambda: select("x", "y", "z", "a").from_("tbl").cluster_by("x, y", "z").cluster_by("a"), + lambda: select("x", "y", "z", "a") + .from_("tbl") + .cluster_by("x, y", "z") + .cluster_by("a"), "SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a", ), ( @@ -239,7 +255,9 @@ class TestBuild(unittest.TestCase): "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", ), ( - lambda: select("x").from_("tbl").with_("tbl", as_="SELECT x FROM tbl2", recursive=True), + lambda: select("x") + .from_("tbl") + .with_("tbl", as_="SELECT x FROM tbl2", recursive=True), "WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", ), ( @@ -247,7 +265,9 @@ class TestBuild(unittest.TestCase): "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", ), ( - lambda: select("x").from_("tbl").with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")), + lambda: select("x") + .from_("tbl") + .with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")), "WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl", ), ( @@ -258,7 +278,10 @@ class TestBuild(unittest.TestCase): "WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl", ), ( - lambda: select("x").from_("tbl").with_("tbl", as_=select("x", "y").from_("tbl2")).select("y"), + lambda: select("x") + .from_("tbl") + .with_("tbl", as_=select("x", "y").from_("tbl2")) + .select("y"), "WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl", ), ( @@ -266,35 +289,59 @@ class TestBuild(unittest.TestCase): "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").group_by("x"), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .group_by("x"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").order_by("x"), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .order_by("x"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").limit(10), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .limit(10), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").offset(10), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .offset(10), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").join("tbl3"), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .join("tbl3"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").distinct(), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .distinct(), "WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").where("x > 10"), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .where("x > 10"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10", ), ( - lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").having("x > 20"), + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .having("x > 20"), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20", ), (lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"), @@ -354,7 +401,9 @@ class TestBuild(unittest.TestCase): "SELECT x FROM (SELECT x FROM tbl) AS foo WHERE x > 0", ), ( - lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"), + lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select( + "x" + ), "SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned", ), ( diff --git a/tests/test_executor.py b/tests/test_executor.py index ef1a706..49805b9 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -33,7 +33,10 @@ class TestExecutor(unittest.TestCase): ) cls.cache = {} - cls.sqls = [(sql, expected) for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")] + cls.sqls = [ + (sql, expected) + for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql") + ] @classmethod def tearDownClass(cls): @@ -63,7 +66,9 @@ class TestExecutor(unittest.TestCase): def test_execute_tpch(self): def to_csv(expression): if isinstance(expression, exp.Table): - return parse_one(f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}") + return parse_one( + f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}" + ) return expression for sql, _ in self.sqls[0:3]: diff --git a/tests/test_expressions.py b/tests/test_expressions.py index adfd329..63371d8 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -30,7 +30,9 @@ class TestExpressions(unittest.TestCase): self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)")) self.assertEqual(exp.Table(pivots=[]), exp.Table()) self.assertNotEqual(exp.Table(pivots=[None]), exp.Table()) - self.assertEqual(exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False)) + self.assertEqual( + exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False) + ) def test_find(self): expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y") @@ -89,7 +91,9 @@ class TestExpressions(unittest.TestCase): self.assertIsNone(column.find_ancestor(exp.Join)) def test_alias_or_name(self): - expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz") + expression = parse_one( + "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" + ) self.assertEqual( [e.alias_or_name for e in expression.expressions], ["a", "B", "e", "*", "zz", "z"], @@ -166,7 +170,9 @@ class TestExpressions(unittest.TestCase): "SELECT * FROM foo WHERE ? > 100", ) self.assertEqual( - exp.replace_placeholders(parse_one("select * from :name WHERE ? > 100"), another_name="bla").sql(), + exp.replace_placeholders( + parse_one("select * from :name WHERE ? > 100"), another_name="bla" + ).sql(), "SELECT * FROM :name WHERE ? > 100", ) self.assertEqual( @@ -183,7 +189,9 @@ class TestExpressions(unittest.TestCase): ) def test_named_selects(self): - expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz") + expression = parse_one( + "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" + ) self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"]) expression = parse_one( @@ -367,7 +375,9 @@ class TestExpressions(unittest.TestCase): self.assertEqual(len(list(expression.walk())), 9) self.assertEqual(len(list(expression.walk(bfs=False))), 9) self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk())) - self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))) + self.assertTrue( + all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)) + ) def test_functions(self): self.assertIsInstance(parse_one("ABS(a)"), exp.Abs) @@ -512,14 +522,21 @@ class TestExpressions(unittest.TestCase): ), exp.Properties( expressions=[ - exp.FileFormatProperty(this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")), + exp.FileFormatProperty( + this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet") + ), exp.PartitionedByProperty( this=exp.Literal.string("PARTITIONED_BY"), - value=exp.Tuple(expressions=[exp.to_identifier("a"), exp.to_identifier("b")]), + value=exp.Tuple( + expressions=[exp.to_identifier("a"), exp.to_identifier("b")] + ), + ), + exp.AnonymousProperty( + this=exp.Literal.string("custom"), value=exp.Literal.number(1) ), - exp.AnonymousProperty(this=exp.Literal.string("custom"), value=exp.Literal.number(1)), exp.TableFormatProperty( - this=exp.Literal.string("TABLE_FORMAT"), value=exp.to_identifier("test_format") + this=exp.Literal.string("TABLE_FORMAT"), + value=exp.to_identifier("test_format"), ), exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL), exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE), @@ -538,7 +555,10 @@ class TestExpressions(unittest.TestCase): ((1, "2", None), "(1, '2', NULL)"), ([1, "2", None], "ARRAY(1, '2', NULL)"), ({"x": None}, "MAP('x', NULL)"), - (datetime.datetime(2022, 10, 1, 1, 1, 1), "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')"), + ( + datetime.datetime(2022, 10, 1, 1, 1, 1), + "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')", + ), ( datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc), "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000+0000')", @@ -548,30 +568,48 @@ class TestExpressions(unittest.TestCase): with self.subTest(value): self.assertEqual(exp.convert(value).sql(), expected) - def test_annotation_alias(self): - sql = "SELECT a, b AS B, c # comment, d AS D # another_comment FROM foo" + def test_comment_alias(self): + sql = """ + SELECT + a, + b AS B, + c, /*comment*/ + d AS D, -- another comment + CAST(x AS INT) -- final comment + FROM foo + """ expression = parse_one(sql) self.assertEqual( [e.alias_or_name for e in expression.expressions], - ["a", "B", "c", "D"], + ["a", "B", "c", "D", "x"], + ) + self.assertEqual( + expression.sql(), + "SELECT a, b AS B, c /* comment */, d AS D /* another comment */, CAST(x AS INT) /* final comment */ FROM foo", + ) + self.assertEqual( + expression.sql(comments=False), + "SELECT a, b AS B, c, d AS D, CAST(x AS INT) FROM foo", ) - self.assertEqual(expression.sql(), "SELECT a, b AS B, c, d AS D") - self.assertEqual(expression.expressions[2].name, "comment") self.assertEqual( - expression.sql(pretty=True, annotations=False), + expression.sql(pretty=True, comments=False), """SELECT a, b AS B, c, - d AS D""", + d AS D, + CAST(x AS INT) +FROM foo""", ) self.assertEqual( expression.sql(pretty=True), """SELECT a, b AS B, - c # comment, - d AS D # another_comment FROM foo""", + c, -- comment + d AS D, -- another comment + CAST(x AS INT) -- final comment +FROM foo""", ) def test_to_table(self): @@ -605,5 +643,9 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(expression, exp.Union) self.assertEqual(expression.named_selects, ["cola", "colb"]) self.assertEqual( - expression.selects, [exp.Column(this=exp.to_identifier("cola")), exp.Column(this=exp.to_identifier("colb"))] + expression.selects, + [ + exp.Column(this=exp.to_identifier("cola")), + exp.Column(this=exp.to_identifier("colb")), + ], ) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 3b5990f..a1b7e70 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -67,7 +67,9 @@ class TestOptimizer(unittest.TestCase): } def check_file(self, file, func, pretty=False, execute=False, **kwargs): - for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1): + for i, (meta, sql, expected) in enumerate( + load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1 + ): title = meta.get("title") or f"{i}, {sql}" dialect = meta.get("dialect") leave_tables_isolated = meta.get("leave_tables_isolated") @@ -90,7 +92,9 @@ class TestOptimizer(unittest.TestCase): if string_to_bool(should_execute): with self.subTest(f"(execute) {title}"): - df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df() + df1 = self.conn.execute( + sqlglot.transpile(sql, read=dialect, write="duckdb")[0] + ).df() df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df() assert_frame_equal(df1, df2) @@ -268,7 +272,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)") self.assertEqual( - scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)" + scopes[3].expression.sql(), + "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", ) self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y") self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") @@ -287,7 +292,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') # Check that we can walk in scope from an arbitrary node self.assertEqual( - {node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)}, + { + node.sql() + for node, *_ in walk_in_scope(expression.find(exp.Where)) + if isinstance(node, exp.Column) + }, {"s.b"}, ) @@ -324,7 +333,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT) def test_cache_annotation(self): - expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")) + expression = annotate_types( + parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1") + ) self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT) def test_binary_annotation(self): @@ -384,7 +395,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') """ expression = annotate_types(parse_one(sql), schema=schema) - self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT) # tbl.cola + tbl.colb + 'foo' AS col + self.assertEqual( + expression.expressions[0].type, exp.DataType.Type.TEXT + ) # tbl.cola + tbl.colb + 'foo' AS col outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo' self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT) @@ -396,7 +409,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT) cte_select = expression.args["with"].expressions[0].this - self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola + 'bla' AS cola + self.assertEqual( + cte_select.expressions[0].type, exp.DataType.Type.VARCHAR + ) # x.cola + 'bla' AS cola self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla' @@ -405,7 +420,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR) # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively - for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]): + for d, t in zip( + cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT] + ): self.assertEqual(d.this.expressions[0].this.type, t) def test_function_annotation(self): @@ -421,6 +438,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb) self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb + sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x" + + case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] + self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR) + + case_expr = case_expr_alias.this + self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR) + self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR) + + case_ifs_expr = case_expr.args["ifs"][0] + self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR) + self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR) + def test_unknown_annotation(self): schema = {"x": {"cola": "VARCHAR"}} sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x" @@ -431,8 +461,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') concat_expr = concat_expr_alias.this self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN) self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola - self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) # SOME_ANONYMOUS_FUNC(x.cola) - self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola (arg) + self.assertEqual( + concat_expr.right.type, exp.DataType.Type.UNKNOWN + ) # SOME_ANONYMOUS_FUNC(x.cola) + self.assertEqual( + concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR + ) # x.cola (arg) def test_null_annotation(self): expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this diff --git a/tests/test_parser.py b/tests/test_parser.py index 9afeae6..04c20b1 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -23,8 +23,6 @@ class TestParser(unittest.TestCase): def test_float(self): self.assertEqual(parse_one(".2"), parse_one("0.2")) - self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)")) - self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)")) def test_table(self): tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)] @@ -33,7 +31,9 @@ class TestParser(unittest.TestCase): def test_select(self): self.assertIsNotNone(parse_one("select 1 natural")) self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"]) - self.assertIsNotNone(parse_one("select * from x where a = (select 1) order by x.y").args["order"]) + self.assertIsNotNone( + parse_one("select * from x where a = (select 1) order by x.y").args["order"] + ) self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1) self.assertEqual( parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(), @@ -125,26 +125,70 @@ class TestParser(unittest.TestCase): def test_var(self): self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'") - def test_annotations(self): + def test_comments(self): expression = parse_one( """ - SELECT - a #annotation1, - b as B #annotation2:testing , - "test#annotation",c#annotation3, d #annotation4, - e #, - f # space + --comment1 + SELECT /* this won't be used */ + a, --comment2 + b as B, --comment3:testing + "test--annotation", + c, --comment4 --foo + e, -- + f -- space FROM foo """ ) - assert expression.expressions[0].name == "annotation1" - assert expression.expressions[1].name == "annotation2:testing" - assert expression.expressions[2].name == "test#annotation" - assert expression.expressions[3].name == "annotation3" - assert expression.expressions[4].name == "annotation4" - assert expression.expressions[5].name == "" - assert expression.expressions[6].name == "space" + self.assertEqual(expression.comment, "comment1") + self.assertEqual(expression.expressions[0].comment, "comment2") + self.assertEqual(expression.expressions[1].comment, "comment3:testing") + self.assertEqual(expression.expressions[2].comment, None) + self.assertEqual(expression.expressions[3].comment, "comment4 --foo") + self.assertEqual(expression.expressions[4].comment, "") + self.assertEqual(expression.expressions[5].comment, " space") + + def test_type_literals(self): + self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)")) + self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)")) + self.assertEqual( + parse_one("TIMESTAMP '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP)" + ) + self.assertEqual( + parse_one("TIMESTAMP(1) '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP(1))" + ) + self.assertEqual( + parse_one("TIMESTAMP WITH TIME ZONE '2022-01-01'").sql(), + "CAST('2022-01-01' AS TIMESTAMPTZ)", + ) + self.assertEqual( + parse_one("TIMESTAMP WITH LOCAL TIME ZONE '2022-01-01'").sql(), + "CAST('2022-01-01' AS TIMESTAMPLTZ)", + ) + self.assertEqual( + parse_one("TIMESTAMP WITHOUT TIME ZONE '2022-01-01'").sql(), + "CAST('2022-01-01' AS TIMESTAMP)", + ) + self.assertEqual( + parse_one("TIMESTAMP(1) WITH TIME ZONE '2022-01-01'").sql(), + "CAST('2022-01-01' AS TIMESTAMPTZ(1))", + ) + self.assertEqual( + parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE '2022-01-01'").sql(), + "CAST('2022-01-01' AS TIMESTAMPLTZ(1))", + ) + self.assertEqual( + parse_one("TIMESTAMP(1) WITHOUT TIME ZONE '2022-01-01'").sql(), + "CAST('2022-01-01' AS TIMESTAMP(1))", + ) + self.assertEqual(parse_one("TIMESTAMP(1) WITH TIME ZONE").sql(), "TIMESTAMPTZ(1)") + self.assertEqual(parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE").sql(), "TIMESTAMPLTZ(1)") + self.assertEqual(parse_one("TIMESTAMP(1) WITHOUT TIME ZONE").sql(), "TIMESTAMP(1)") + self.assertEqual(parse_one("""JSON '{"x":"y"}'""").sql(), """CAST('{"x":"y"}' AS JSON)""") + self.assertIsInstance(parse_one("TIMESTAMP(1)"), exp.Func) + self.assertIsInstance(parse_one("TIMESTAMP('2022-01-01')"), exp.Func) + self.assertIsInstance(parse_one("TIMESTAMP()"), exp.Func) + self.assertIsInstance(parse_one("map.x"), exp.Column) def test_pretty_config_override(self): self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x") diff --git a/tests/test_schema.py b/tests/test_schema.py index bab97d8..cc0e3d1 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,281 +1,141 @@ import unittest -from sqlglot import table -from sqlglot.dataframe.sql import types as df_types +from sqlglot import exp, to_table +from sqlglot.errors import SchemaError from sqlglot.schema import MappingSchema, ensure_schema class TestSchema(unittest.TestCase): - def test_schema(self): - schema = ensure_schema( - { - "x": { - "a": "uint64", - } - } - ) - self.assertEqual( - schema.column_names( - table( - "x", - ) - ), - ["a"], - ) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db")) - with self.assertRaises(ValueError): - schema.column_names(table("x2")) + def assert_column_names(self, schema, *table_results): + for table, result in table_results: + with self.subTest(f"{table} -> {result}"): + self.assertEqual(schema.column_names(to_table(table)), result) - with self.assertRaises(ValueError): - schema.add_table(table("y", db="db"), {"b": "string"}) - with self.assertRaises(ValueError): - schema.add_table(table("y", db="db", catalog="c"), {"b": "string"}) + def assert_column_names_raises(self, schema, *tables): + for table in tables: + with self.subTest(table): + with self.assertRaises(SchemaError): + schema.column_names(to_table(table)) - schema.add_table(table("y"), {"b": "string"}) - schema_with_y = { - "x": { - "a": "uint64", - }, - "y": { - "b": "string", - }, - } - self.assertEqual(schema.schema, schema_with_y) - - new_schema = schema.copy() - new_schema.add_table(table("z"), {"c": "string"}) - self.assertEqual(schema.schema, schema_with_y) - self.assertEqual( - new_schema.schema, - { - "x": { - "a": "uint64", - }, - "y": { - "b": "string", - }, - "z": { - "c": "string", - }, - }, - ) - schema.add_table(table("m"), {"d": "string"}) - schema.add_table(table("n"), {"e": "string"}) - schema_with_m_n = { - "x": { - "a": "uint64", - }, - "y": { - "b": "string", - }, - "m": { - "d": "string", - }, - "n": { - "e": "string", - }, - } - self.assertEqual(schema.schema, schema_with_m_n) - new_schema = schema.copy() - new_schema.add_table(table("o"), {"f": "string"}) - new_schema.add_table(table("p"), {"g": "string"}) - self.assertEqual(schema.schema, schema_with_m_n) - self.assertEqual( - new_schema.schema, + def test_schema(self): + schema = ensure_schema( { "x": { "a": "uint64", }, "y": { - "b": "string", - }, - "m": { - "d": "string", - }, - "n": { - "e": "string", - }, - "o": { - "f": "string", - }, - "p": { - "g": "string", + "b": "uint64", + "c": "uint64", }, }, ) - schema = ensure_schema( - { - "db": { - "x": { - "a": "uint64", - } - } - } + self.assert_column_names( + schema, + ("x", ["a"]), + ("y", ["b", "c"]), + ("z.x", ["a"]), + ("z.x.y", ["b", "c"]), ) - self.assertEqual(schema.column_names(table("x", db="db")), ["a"]) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c")) - with self.assertRaises(ValueError): - schema.column_names(table("x")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db2")) - with self.assertRaises(ValueError): - schema.column_names(table("x2", db="db")) - with self.assertRaises(ValueError): - schema.add_table(table("y"), {"b": "string"}) - with self.assertRaises(ValueError): - schema.add_table(table("y", db="db", catalog="c"), {"b": "string"}) + self.assert_column_names_raises( + schema, + "z", + "z.z", + "z.z.z", + ) - schema.add_table(table("y", db="db"), {"b": "string"}) - self.assertEqual( - schema.schema, + def test_schema_db(self): + schema = ensure_schema( { - "db": { + "d1": { "x": { "a": "uint64", }, "y": { - "b": "string", + "b": "uint64", + }, + }, + "d2": { + "x": { + "c": "uint64", }, - } + }, }, ) - schema = ensure_schema( - { - "c": { - "db": { - "x": { - "a": "uint64", - } - } - } - } + self.assert_column_names( + schema, + ("d1.x", ["a"]), + ("d2.x", ["c"]), + ("y", ["b"]), + ("d1.y", ["b"]), + ("z.d1.y", ["b"]), ) - self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"]) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db")) - with self.assertRaises(ValueError): - schema.column_names(table("x")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c2")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db2")) - with self.assertRaises(ValueError): - schema.column_names(table("x2", db="db")) - with self.assertRaises(ValueError): - schema.add_table(table("x"), {"b": "string"}) - with self.assertRaises(ValueError): - schema.add_table(table("x", db="db"), {"b": "string"}) + self.assert_column_names_raises( + schema, + "x", + "z.x", + "z.y", + ) - schema.add_table(table("y", db="db", catalog="c"), {"a": "string", "b": "int"}) - self.assertEqual( - schema.schema, + def test_schema_catalog(self): + schema = ensure_schema( { - "c": { - "db": { + "c1": { + "d1": { "x": { "a": "uint64", }, "y": { - "a": "string", - "b": "int", + "b": "uint64", }, - } - } - }, - ) - schema.add_table(table("z", db="db2", catalog="c"), {"c": "string", "d": "int"}) - self.assertEqual( - schema.schema, - { - "c": { - "db": { - "x": { - "a": "uint64", + "z": { + "c": "uint64", }, + }, + }, + "c2": { + "d1": { "y": { - "a": "string", - "b": "int", + "d": "uint64", }, - }, - "db2": { "z": { - "c": "string", - "d": "int", - } - }, - } - }, - ) - schema.add_table(table("m", db="db2", catalog="c2"), {"e": "string", "f": "int"}) - self.assertEqual( - schema.schema, - { - "c": { - "db": { - "x": { - "a": "uint64", - }, - "y": { - "a": "string", - "b": "int", + "e": "uint64", }, }, - "db2": { + "d2": { "z": { - "c": "string", - "d": "int", - } + "f": "uint64", + }, }, }, - "c2": { - "db2": { - "m": { - "e": "string", - "f": "int", - } - } - }, - }, - ) - - schema = ensure_schema( - { - "x": { - "a": "uint64", - } } ) - self.assertEqual(schema.column_names(table("x")), ["a"]) - schema = MappingSchema() - schema.add_table(table("x"), {"a": "string"}) - self.assertEqual( - schema.schema, - { - "x": { - "a": "string", - } - }, + self.assert_column_names( + schema, + ("x", ["a"]), + ("d1.x", ["a"]), + ("c1.d1.x", ["a"]), + ("c1.d1.y", ["b"]), + ("c1.d1.z", ["c"]), + ("c2.d1.y", ["d"]), + ("c2.d1.z", ["e"]), + ("d2.z", ["f"]), + ("c2.d2.z", ["f"]), ) - schema.add_table(table("y"), df_types.StructType([df_types.StructField("b", df_types.StringType())])) - self.assertEqual( - schema.schema, - { - "x": { - "a": "string", - }, - "y": { - "b": "string", - }, - }, + + self.assert_column_names_raises( + schema, + "q", + "d2.x", + "y", + "z", + "d1.y", + "d1.z", + "a.b.c", ) def test_schema_add_table_with_and_without_mapping(self): @@ -288,3 +148,34 @@ class TestSchema(unittest.TestCase): self.assertEqual(schema.column_names("test"), ["x", "y"]) schema.add_table("test") self.assertEqual(schema.column_names("test"), ["x", "y"]) + + def test_schema_get_column_type(self): + schema = MappingSchema({"a": {"b": "varchar"}}) + self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR) + self.assertEqual( + schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")), + exp.DataType.Type.VARCHAR, + ) + self.assertEqual( + schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR + ) + self.assertEqual( + schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR + ) + schema = MappingSchema({"a": {"b": {"c": "varchar"}}}) + self.assertEqual( + schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")), + exp.DataType.Type.VARCHAR, + ) + self.assertEqual( + schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR + ) + schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}}) + self.assertEqual( + schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")), + exp.DataType.Type.VARCHAR, + ) + self.assertEqual( + schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"), + exp.DataType.Type.VARCHAR, + ) diff --git a/tests/test_tokens.py b/tests/test_tokens.py new file mode 100644 index 0000000..943c2b0 --- /dev/null +++ b/tests/test_tokens.py @@ -0,0 +1,18 @@ +import unittest + +from sqlglot.tokens import Tokenizer + + +class TestTokens(unittest.TestCase): + def test_comment_attachment(self): + tokenizer = Tokenizer() + sql_comment = [ + ("/*comment*/ foo", "comment"), + ("/*comment*/ foo --test", "comment"), + ("--comment\nfoo --test", "comment"), + ("foo --comment", "comment"), + ("foo", None), + ] + + for sql, comment in sql_comment: + self.assertEqual(tokenizer.tokenize(sql)[0].comment, comment) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 01b8205..942053e 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -49,6 +49,12 @@ class TestTranspile(unittest.TestCase): leading_comma=True, pretty=True, ) + self.validate( + "SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ", + "SELECT\n FOO -- x\n , BAR -- y\n , BAZ", + leading_comma=True, + pretty=True, + ) # without pretty, this should be a no-op self.validate( "SELECT FOO, BAR, BAZ", @@ -63,24 +69,61 @@ class TestTranspile(unittest.TestCase): self.validate("SELECT 3>=3", "SELECT 3 >= 3") def test_comments(self): - self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo") - self.validate("SELECT 1 /* inline */ FROM foo -- comment", "SELECT 1 FROM foo") - + self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */") + self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo") + self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo") + self.validate( + "SELECT 1 /* inline */ FROM foo -- comment", + "SELECT 1 /* inline */ FROM foo /* comment */", + ) + self.validate( + "SELECT FUN(x) /*x*/, [1,2,3] /*y*/", "SELECT FUN(x) /* x */, ARRAY(1, 2, 3) /* y */" + ) self.validate( """ SELECT 1 -- comment FROM foo -- comment """, - "SELECT 1 FROM foo", + "SELECT 1 /* comment */ FROM foo /* comment */", ) - self.validate( """ SELECT 1 /* big comment like this */ FROM foo -- comment """, - "SELECT 1 FROM foo", + """SELECT 1 /* big comment + like this */ FROM foo /* comment */""", + ) + self.validate( + "select x from foo -- x", + "SELECT x FROM foo /* x */", + ) + self.validate( + """ + /* multi + line + comment + */ + SELECT + tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */, + CAST(x AS INT), # comment 3 + y -- comment 4 + FROM + bar /* comment 5 */, + tbl # comment 6 + """, + """/* multi + line + comment + */ +SELECT + tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */, + CAST(x AS INT), -- comment 3 + y -- comment 4 +FROM bar /* comment 5 */, tbl /* comment 6 */""", + read="mysql", + pretty=True, ) def test_types(self): @@ -146,6 +189,16 @@ class TestTranspile(unittest.TestCase): def test_ignore_nulls(self): self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)") + def test_with(self): + self.validate( + "WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *", + "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *", + ) + self.validate( + "WITH a AS (SELECT 1), WITH b AS (SELECT 2) SELECT *", + "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *", + ) + def test_time(self): self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)") self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)") |