summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/dataframe/integration/dataframe_validator.py52
-rw-r--r--tests/dataframe/integration/test_dataframe.py295
-rw-r--r--tests/dataframe/integration/test_session.py12
-rw-r--r--tests/dataframe/unit/dataframe_sql_validator.py12
-rw-r--r--tests/dataframe/unit/test_column.py14
-rw-r--r--tests/dataframe/unit/test_dataframe.py4
-rw-r--r--tests/dataframe/unit/test_functions.py55
-rw-r--r--tests/dataframe/unit/test_session.py11
-rw-r--r--tests/dataframe/unit/test_types.py5
-rw-r--r--tests/dataframe/unit/test_window.py32
-rw-r--r--tests/dialects/test_bigquery.py12
-rw-r--r--tests/dialects/test_clickhouse.py5
-rw-r--r--tests/dialects/test_databricks.py3
-rw-r--r--tests/dialects/test_dialect.py156
-rw-r--r--tests/dialects/test_mysql.py294
-rw-r--r--tests/dialects/test_postgres.py45
-rw-r--r--tests/dialects/test_redshift.py4
-rw-r--r--tests/dialects/test_snowflake.py117
-rw-r--r--tests/dialects/test_spark.py4
-rw-r--r--tests/dialects/test_starrocks.py3
-rw-r--r--tests/dialects/test_tsql.py29
-rw-r--r--tests/fixtures/identity.sql12
-rw-r--r--tests/fixtures/optimizer/qualify_columns.sql10
-rw-r--r--tests/fixtures/optimizer/simplify.sql9
-rw-r--r--tests/fixtures/pretty.sql28
-rw-r--r--tests/test_build.py87
-rw-r--r--tests/test_executor.py9
-rw-r--r--tests/test_expressions.py82
-rw-r--r--tests/test_optimizer.py54
-rw-r--r--tests/test_parser.py78
-rw-r--r--tests/test_schema.py345
-rw-r--r--tests/test_tokens.py18
-rw-r--r--tests/test_transpile.py65
33 files changed, 1443 insertions, 518 deletions
diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py
index 4a89c78..16f8922 100644
--- a/tests/dataframe/integration/dataframe_validator.py
+++ b/tests/dataframe/integration/dataframe_validator.py
@@ -1,9 +1,9 @@
-import sys
import typing as t
import unittest
import warnings
import sqlglot
+from sqlglot.helper import PYTHON_VERSION
from tests.helpers import SKIP_INTEGRATION
if t.TYPE_CHECKING:
@@ -11,7 +11,8 @@ if t.TYPE_CHECKING:
@unittest.skipIf(
- SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set"
+ SKIP_INTEGRATION or PYTHON_VERSION > (3, 10),
+ "Skipping Integration Tests since `SKIP_INTEGRATION` is set",
)
class DataFrameValidator(unittest.TestCase):
spark = None
@@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase):
# This is for test `test_branching_root_dataframes`
config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
- cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
+ cls.spark = (
+ SparkSession.builder.master("local[*]")
+ .appName("Unit-tests")
+ .config(conf=config)
+ .getOrCreate()
+ )
cls.spark.sparkContext.setLogLevel("ERROR")
cls.sqlglot = SqlglotSparkSession()
cls.spark_employee_schema = types.StructType(
@@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase):
)
cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
[
- sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
+ sqlglotSparkTypes.StructField(
+ "employee_id", sqlglotSparkTypes.IntegerType(), False
+ ),
sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
@@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase):
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
- cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
- cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
+ cls.df_employee = cls.spark.createDataFrame(
+ data=employee_data, schema=cls.spark_employee_schema
+ )
+ cls.dfs_employee = cls.sqlglot.createDataFrame(
+ data=employee_data, schema=cls.sqlglot_employee_schema
+ )
cls.df_employee.createOrReplaceTempView("employee")
cls.spark_store_schema = types.StructType(
@@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase):
[
sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
- sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
+ sqlglotSparkTypes.StructField(
+ "district_id", sqlglotSparkTypes.IntegerType(), False
+ ),
sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
]
)
@@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase):
(2, "Arrow", 2, 2000),
]
cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
- cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
+ cls.dfs_store = cls.sqlglot.createDataFrame(
+ data=store_data, schema=cls.sqlglot_store_schema
+ )
cls.df_store.createOrReplaceTempView("store")
cls.spark_district_schema = types.StructType(
@@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase):
)
cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
[
- sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
- sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
- sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
+ sqlglotSparkTypes.StructField(
+ "district_id", sqlglotSparkTypes.IntegerType(), False
+ ),
+ sqlglotSparkTypes.StructField(
+ "district_name", sqlglotSparkTypes.StringType(), False
+ ),
+ sqlglotSparkTypes.StructField(
+ "manager_name", sqlglotSparkTypes.StringType(), False
+ ),
]
)
district_data = [
(1, "Temple", "Dogen"),
(2, "Lighthouse", "Jacob"),
]
- cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
- cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
+ cls.df_district = cls.spark.createDataFrame(
+ data=district_data, schema=cls.spark_district_schema
+ )
+ cls.dfs_district = cls.sqlglot.createDataFrame(
+ data=district_data, schema=cls.sqlglot_district_schema
+ )
cls.df_district.createOrReplaceTempView("district")
sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
sqlglot.schema.add_table("store", cls.sqlglot_store_schema)
diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py
index c740bec..19e3b89 100644
--- a/tests/dataframe/integration/test_dataframe.py
+++ b/tests/dataframe/integration/test_dataframe.py
@@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator):
def test_alias_with_select(self):
df_employee = self.df_spark_employee.alias("df_employee").select(
- self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname
+ self.df_spark_employee["employee_id"],
+ F.col("df_employee.fname"),
+ self.df_spark_employee.lname,
)
dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select(
- self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname
+ self.df_sqlglot_employee["employee_id"],
+ SF.col("dfs_employee.fname"),
+ self.df_sqlglot_employee.lname,
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_case_when_otherwise(self):
df = self.df_spark_employee.select(
- F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60"))
+ F.when(
+ (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+ F.lit("between 40 and 60"),
+ )
.when(F.col("age") < F.lit(40), "less than 40")
.otherwise("greater than 60")
)
dfs = self.df_sqlglot_employee.select(
- SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60"))
+ SF.when(
+ (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+ SF.lit("between 40 and 60"),
+ )
.when(SF.col("age") < SF.lit(40), "less than 40")
.otherwise("greater than 60")
)
@@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator):
def test_case_when_no_otherwise(self):
df = self.df_spark_employee.select(
- F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when(
- F.col("age") < F.lit(40), "less than 40"
- )
+ F.when(
+ (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+ F.lit("between 40 and 60"),
+ ).when(F.col("age") < F.lit(40), "less than 40")
)
dfs = self.df_sqlglot_employee.select(
- SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when(
- SF.col("age") < SF.lit(40), "less than 40"
- )
+ SF.when(
+ (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+ SF.lit("between 40 and 60"),
+ ).when(SF.col("age") < SF.lit(40), "less than 40")
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_where_clause_multiple_and(self):
- df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")))
+ df_employee = self.df_spark_employee.where(
+ (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))
+ )
dfs_employee = self.df_sqlglot_employee.where(
(SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack"))
)
@@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_where_clause_multiple_or(self):
- df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")))
+ df_employee = self.df_spark_employee.where(
+ (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))
+ )
dfs_employee = self.df_sqlglot_employee.where(
(SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate"))
)
@@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator):
dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37))
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0))
+ df_employee = self.df_spark_employee.where(
+ self.df_spark_employee["age"] % F.lit(5) == F.lit(0)
+ )
+ dfs_employee = self.df_sqlglot_employee.where(
+ self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)
+ )
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28))
+ df_employee = self.df_spark_employee.where(
+ self.df_spark_employee["age"] + F.lit(5) > F.lit(28)
+ )
+ dfs_employee = self.df_sqlglot_employee.where(
+ self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)
+ )
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28))
+ df_employee = self.df_spark_employee.where(
+ self.df_spark_employee["age"] - F.lit(5) > F.lit(28)
+ )
+ dfs_employee = self.df_sqlglot_employee.where(
+ self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)
+ )
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
df_employee = self.df_spark_employee.where(
self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2)
)
dfs_employee = self.df_sqlglot_employee.where(
- self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2)
+ self.df_sqlglot_employee["age"] * SF.lit(0.5)
+ == self.df_sqlglot_employee["age"] / SF.lit(2)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_join_inner(self):
- df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select(
+ df_joined = self.df_spark_employee.join(
+ self.df_spark_store, on=["store_id"], how="inner"
+ ).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
F.col("lname"),
@@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store.store_name,
self.df_spark_store["num_sales"],
)
- dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select(
+ dfs_joined = self.df_sqlglot_employee.join(
+ self.df_sqlglot_store, on=["store_id"], how="inner"
+ ).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
SF.col("lname"),
@@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_inner_no_select(self):
- df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join(
- self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner"
+ df_joined = self.df_spark_employee.select(
+ F.col("store_id"), F.col("fname"), F.col("lname")
+ ).join(
+ self.df_spark_store.select(F.col("store_id"), F.col("store_name")),
+ on=["store_id"],
+ how="inner",
)
- dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner"
+ dfs_joined = self.df_sqlglot_employee.select(
+ SF.col("store_id"), SF.col("fname"), SF.col("lname")
+ ).join(
+ self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")),
+ on=["store_id"],
+ how="inner",
)
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_inner_equality_single(self):
df_joined = self.df_spark_employee.join(
- self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner"
+ self.df_spark_store,
+ on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
+ how="inner",
).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
@@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store["num_sales"],
)
dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner"
+ self.df_sqlglot_store,
+ on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
+ how="inner",
).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
@@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_full_outer(self):
- df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select(
+ df_joined = self.df_spark_employee.join(
+ self.df_spark_store, on=["store_id"], how="full_outer"
+ ).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
F.col("lname"),
@@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store.store_name,
self.df_spark_store["num_sales"],
)
- dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select(
+ dfs_joined = self.df_sqlglot_employee.join(
+ self.df_sqlglot_store, on=["store_id"], how="full_outer"
+ ).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
SF.col("lname"),
@@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator):
def test_triple_join(self):
df = (
- self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id)
+ self.df_employee.join(
+ self.df_store, on=self.df_employee.employee_id == self.df_store.store_id
+ )
.join(self.df_district, on=self.df_store.store_id == self.df_district.district_id)
.select(
self.df_employee.employee_id,
@@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator):
)
)
dfs = (
- self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id)
+ self.dfs_employee.join(
+ self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id
+ )
.join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id)
.select(
self.dfs_employee.employee_id,
@@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_join_select_and_select_start(self):
- df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join(
- self.df_spark_store, "store_id", "inner"
- )
+ df = self.df_spark_employee.select(
+ F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
+ ).join(self.df_spark_store, "store_id", "inner")
- dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join(
- self.df_sqlglot_store, "store_id", "inner"
- )
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
+ ).join(self.df_sqlglot_store, "store_id", "inner")
self.compare_spark_with_sqlglot(df, dfs)
@@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator):
dfs_unioned = (
self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"))
.unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")))
- .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")))
+ .unionAll(
+ self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
+ )
)
self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
def test_union_by_name(self):
- df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName(
+ df = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("fname"), F.col("lname")
+ ).unionByName(
self.df_spark_store.select(
F.col("store_name").alias("lname"),
F.col("store_id").alias("employee_id"),
@@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator):
)
)
- dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName(
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("fname"), SF.col("lname")
+ ).unionByName(
self.df_sqlglot_store.select(
SF.col("store_name").alias("lname"),
SF.col("store_id").alias("employee_id"),
@@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_order_by_default(self):
- df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id"))
+ df = (
+ self.df_spark_store.groupBy(F.col("district_id"))
+ .agg(F.min("num_sales"))
+ .orderBy(F.col("district_id"))
+ )
dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id"))
+ self.df_sqlglot_store.groupBy(SF.col("district_id"))
+ .agg(SF.min("num_sales"))
+ .orderBy(SF.col("district_id"))
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator):
df = (
self.df_spark_store.groupBy(F.col("district_id"))
.agg(F.min("num_sales").alias("total_sales"))
- .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last())
+ .orderBy(
+ F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()
+ )
)
dfs = (
self.df_sqlglot_store.groupBy(SF.col("district_id"))
.agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last())
+ .orderBy(
+ SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()
+ )
)
self.compare_spark_with_sqlglot(df, dfs)
@@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator):
df = (
self.df_spark_store.groupBy(F.col("district_id"))
.agg(F.min("num_sales").alias("total_sales"))
- .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first())
+ .orderBy(
+ F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()
+ )
)
dfs = (
self.df_sqlglot_store.groupBy(SF.col("district_id"))
.agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first())
+ .orderBy(
+ SF.when(
+ SF.col("district_id") == SF.lit(1), SF.col("district_id")
+ ).desc_nulls_first()
+ )
)
self.compare_spark_with_sqlglot(df, dfs)
def test_intersect(self):
- df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
- self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
- )
+ df_employee_duplicate = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("store_id")
+ ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
- df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
- self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
- )
+ df_store_duplicate = self.df_spark_store.select(
+ F.col("store_id"), F.col("district_id")
+ ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.intersect(df_store_duplicate)
- dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
- )
+ dfs_employee_duplicate = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("store_id")
+ ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
- dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
- )
+ dfs_store_duplicate = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("district_id")
+ ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate)
self.compare_spark_with_sqlglot(df, dfs)
def test_intersect_all(self):
- df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
- self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
- )
+ df_employee_duplicate = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("store_id")
+ ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
- df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
- self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
- )
+ df_store_duplicate = self.df_spark_store.select(
+ F.col("store_id"), F.col("district_id")
+ ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.intersectAll(df_store_duplicate)
- dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
- )
+ dfs_employee_duplicate = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("store_id")
+ ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
- dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
- )
+ dfs_store_duplicate = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("district_id")
+ ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate)
self.compare_spark_with_sqlglot(df, dfs)
def test_except_all(self):
- df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
- self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
- )
+ df_employee_duplicate = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("store_id")
+ ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
- df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
- self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
- )
+ df_store_duplicate = self.df_spark_store.select(
+ F.col("store_id"), F.col("district_id")
+ ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.exceptAll(df_store_duplicate)
- dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
- )
+ dfs_employee_duplicate = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("store_id")
+ ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
- dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
- )
+ dfs_store_duplicate = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("district_id")
+ ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate)
@@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_drop_na_default(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna()
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).dropna()
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator):
).dropna(how="any", thresh=2)
dfs = self.df_sqlglot_employee.select(
- SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+ SF.lit(None),
+ SF.lit(1),
+ SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
).dropna(how="any", thresh=2)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator):
).dropna(thresh=1, subset="the_age")
dfs = self.df_sqlglot_employee.select(
- SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+ SF.lit(None),
+ SF.lit(1),
+ SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
).dropna(thresh=1, subset="the_age")
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_dropna_na_function(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop()
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).na.drop()
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_fillna_default(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100)
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).fillna(100)
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_fillna_na_func(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100)
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).na.fill(100)
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_replace_basic(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100)
+ df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+ to_replace=37, value=100
+ )
dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
to_replace=37, value=100
@@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_replace_mapping(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100})
+ df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+ {37: 100}
+ )
- dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100})
+ dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
+ {37: 100}
+ )
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator):
to_replace=37, value=100
)
- dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace(
- to_replace=37, value=100
- )
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("age"), SF.lit(37).alias("test_col")
+ ).na.replace(to_replace=37, value=100)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator):
"first_name", "first_name_again"
)
- dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed(
- "first_name", "first_name_again"
- )
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("fname").alias("first_name")
+ ).withColumnRenamed("first_name", "first_name_again")
self.compare_spark_with_sqlglot(df, dfs)
def test_drop_column_single(self):
df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age")
- dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age")
+ dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop(
+ "age"
+ )
self.compare_spark_with_sqlglot(df, dfs)
@@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator):
df_sqlglot_employee_cols = self.df_sqlglot_employee.select(
SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
)
- df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))
+ df_sqlglot_store_cols = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("store_name")
+ )
dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop(
df_sqlglot_employee_cols.age,
)
diff --git a/tests/dataframe/integration/test_session.py b/tests/dataframe/integration/test_session.py
index ff1477b..ec50034 100644
--- a/tests/dataframe/integration/test_session.py
+++ b/tests/dataframe/integration/test_session.py
@@ -23,6 +23,14 @@ class TestSessionFunc(DataFrameValidator):
ON
e.store_id = s.store_id
"""
- df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id")))
- dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id")))
+ df = (
+ self.spark.sql(query)
+ .groupBy(F.col("store_id"))
+ .agg(F.countDistinct(F.col("employee_id")))
+ )
+ dfs = (
+ self.sqlglot.sql(query)
+ .groupBy(SF.col("store_id"))
+ .agg(SF.countDistinct(SF.col("employee_id")))
+ )
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py
index fc56553..32ff8f2 100644
--- a/tests/dataframe/unit/dataframe_sql_validator.py
+++ b/tests/dataframe/unit/dataframe_sql_validator.py
@@ -25,11 +25,17 @@ class DataFrameSQLValidator(unittest.TestCase):
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
- self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema)
+ self.df_employee = self.spark.createDataFrame(
+ data=employee_data, schema=self.employee_schema
+ )
- def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False):
+ def compare_sql(
+ self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
+ ):
actual_sqls = df.sql(pretty=pretty)
- expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements
+ expected_statements = (
+ [expected_statements] if isinstance(expected_statements, str) else expected_statements
+ )
self.assertEqual(len(expected_statements), len(actual_sqls))
for expected, actual in zip(expected_statements, actual_sqls):
self.assertEqual(expected, actual)
diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py
index 977971e..da18502 100644
--- a/tests/dataframe/unit/test_column.py
+++ b/tests/dataframe/unit/test_column.py
@@ -26,12 +26,14 @@ class TestDataframeColumn(unittest.TestCase):
def test_and(self):
self.assertEqual(
- "cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql()
+ "cola = colb AND colc = cold",
+ ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql(),
)
def test_or(self):
self.assertEqual(
- "cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql()
+ "cola = colb OR colc = cold",
+ ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql(),
)
def test_mod(self):
@@ -112,7 +114,9 @@ class TestDataframeColumn(unittest.TestCase):
def test_when_otherwise(self):
self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
- self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql())
+ self.assertEqual(
+ "CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()
+ )
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
(F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
@@ -148,7 +152,9 @@ class TestDataframeColumn(unittest.TestCase):
self.assertEqual(
"cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) "
"AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)",
- F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(),
+ F.col("cola")
+ .between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1))
+ .sql(),
)
def test_over(self):
diff --git a/tests/dataframe/unit/test_dataframe.py b/tests/dataframe/unit/test_dataframe.py
index c222cac..e36667b 100644
--- a/tests/dataframe/unit/test_dataframe.py
+++ b/tests/dataframe/unit/test_dataframe.py
@@ -9,7 +9,9 @@ class TestDataframe(DataFrameSQLValidator):
self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
def test_columns(self):
- self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns)
+ self.assertEqual(
+ ["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns
+ )
def test_cache(self):
df = self.df_employee.select("fname").cache()
diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py
index eadbb93..8e5e5cd 100644
--- a/tests/dataframe/unit/test_functions.py
+++ b/tests/dataframe/unit/test_functions.py
@@ -925,12 +925,17 @@ class TestFunctions(unittest.TestCase):
col = SF.window(SF.col("cola"), "10 minutes")
self.assertEqual("WINDOW(cola, '10 minutes')", col.sql())
col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds")
- self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql())
+ self.assertEqual(
+ "WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()
+ )
col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds")
- self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql())
+ self.assertEqual(
+ "WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()
+ )
col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds")
self.assertEqual(
- "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", col_no_slide.sql()
+ "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')",
+ col_no_slide.sql(),
)
def test_session_window(self):
@@ -1359,9 +1364,13 @@ class TestFunctions(unittest.TestCase):
def test_from_json(self):
col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+ self.assertEqual(
+ "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+ )
col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
+ self.assertEqual(
+ "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
+ )
col_no_option = SF.from_json("cola", "cola INT")
self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql())
@@ -1375,7 +1384,9 @@ class TestFunctions(unittest.TestCase):
def test_schema_of_json(self):
col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+ self.assertEqual(
+ "SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+ )
col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
col_no_option = SF.schema_of_json("cola")
@@ -1429,7 +1440,10 @@ class TestFunctions(unittest.TestCase):
col = SF.array_sort(SF.col("cola"))
self.assertEqual("ARRAY_SORT(cola)", col.sql())
col_comparator = SF.array_sort(
- "cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x))
+ "cola",
+ lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(
+ SF.length(y) - SF.length(x)
+ ),
)
self.assertEqual(
"ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)",
@@ -1504,9 +1518,13 @@ class TestFunctions(unittest.TestCase):
def test_from_csv(self):
col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+ self.assertEqual(
+ "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+ )
col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
+ self.assertEqual(
+ "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
+ )
col_no_option = SF.from_csv("cola", "cola INT")
self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql())
@@ -1535,7 +1553,9 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql())
col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count)
- self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql())
+ self.assertEqual(
+ "TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()
+ )
def test_exists(self):
col_str = SF.exists("cola", lambda x: x % 2 == 0)
@@ -1558,10 +1578,13 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql())
col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i))
self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql())
- col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count))
+ col_custom_names = SF.filter(
+ "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)
+ )
self.assertEqual(
- "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql()
+ "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)",
+ col_custom_names.sql(),
)
def test_zip_with(self):
@@ -1570,7 +1593,9 @@ class TestFunctions(unittest.TestCase):
col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y))
self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql())
col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r))
- self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql())
+ self.assertEqual(
+ "ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()
+ )
def test_transform_keys(self):
col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k))
@@ -1586,7 +1611,9 @@ class TestFunctions(unittest.TestCase):
col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v))
self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql())
col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value))
- self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql())
+ self.assertEqual(
+ "TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()
+ )
def test_map_filter(self):
col_str = SF.map_filter("cola", lambda k, v: k > v)
diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py
index 158dcec..7e8bfad 100644
--- a/tests/dataframe/unit/test_session.py
+++ b/tests/dataframe/unit/test_session.py
@@ -21,9 +21,7 @@ class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_no_schema(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
- expected = (
- "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
- )
+ expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
self.compare_sql(df, expected)
def test_cdf_row_mixed_primitives(self):
@@ -77,7 +75,8 @@ class TestDataframeSession(DataFrameSQLValidator):
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
self.assertIn(
- "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False)
+ "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`",
+ df.sql(pretty=False),
)
@mock.patch("sqlglot.schema", MappingSchema())
@@ -104,9 +103,7 @@ class TestDataframeSession(DataFrameSQLValidator):
query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
- expected = (
- "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
- )
+ expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
self.compare_sql(df, expected)
def test_session_create_builder_patterns(self):
diff --git a/tests/dataframe/unit/test_types.py b/tests/dataframe/unit/test_types.py
index 1f6c5dc..52f5d72 100644
--- a/tests/dataframe/unit/test_types.py
+++ b/tests/dataframe/unit/test_types.py
@@ -53,7 +53,10 @@ class TestDataframeTypes(unittest.TestCase):
self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())
def test_map(self):
- self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString())
+ self.assertEqual(
+ "map<int, string>",
+ types.MapType(types.IntegerType(), types.StringType()).simpleString(),
+ )
def test_struct_field(self):
self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())
diff --git a/tests/dataframe/unit/test_window.py b/tests/dataframe/unit/test_window.py
index eea4582..70a868a 100644
--- a/tests/dataframe/unit/test_window.py
+++ b/tests/dataframe/unit/test_window.py
@@ -39,22 +39,38 @@ class TestDataframeWindow(unittest.TestCase):
def test_window_rows_unbounded(self):
rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
- self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql())
+ self.assertEqual(
+ "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
+ rows_between_unbounded_start.sql(),
+ )
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
- self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql())
- rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
- "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql()
+ "OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
+ rows_between_unbounded_end.sql(),
+ )
+ rows_between_unbounded_both = Window.rowsBetween(
+ Window.unboundedPreceding, Window.unboundedFollowing
+ )
+ self.assertEqual(
+ "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
+ rows_between_unbounded_both.sql(),
)
def test_window_range_unbounded(self):
range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
self.assertEqual(
- "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql()
+ "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
+ range_between_unbounded_start.sql(),
)
range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
- self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql())
- range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
- "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql()
+ "OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
+ range_between_unbounded_end.sql(),
+ )
+ range_between_unbounded_both = Window.rangeBetween(
+ Window.unboundedPreceding, Window.unboundedFollowing
+ )
+ self.assertEqual(
+ "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
+ range_between_unbounded_both.sql(),
)
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index 050d41e..a0ebc45 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -157,6 +157,14 @@ class TestBigQuery(Validator):
},
)
+ self.validate_all(
+ "DIV(x, y)",
+ write={
+ "bigquery": "DIV(x, y)",
+ "duckdb": "CAST(x / y AS INT)",
+ },
+ )
+
self.validate_identity(
"SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)"
)
@@ -284,4 +292,6 @@ class TestBigQuery(Validator):
"CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'"
)
self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)")
- self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t")
+ self.validate_identity(
+ "CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t"
+ )
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index 715bf10..efb41bb 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -18,7 +18,6 @@ class TestClickhouse(Validator):
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
},
)
-
self.validate_all(
"CAST(1 AS NULLABLE(Int64))",
write={
@@ -31,3 +30,7 @@ class TestClickhouse(Validator):
"clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))",
},
)
+ self.validate_all(
+ "SELECT x #! comment",
+ write={"": "SELECT x /* comment */"},
+ )
diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py
index e242e73..2168f55 100644
--- a/tests/dialects/test_databricks.py
+++ b/tests/dialects/test_databricks.py
@@ -22,7 +22,8 @@ class TestDatabricks(Validator):
},
)
self.validate_all(
- "SELECT DATEDIFF('end', 'start')", write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"}
+ "SELECT DATEDIFF('end', 'start')",
+ write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"},
)
self.validate_all(
"SELECT DATE_ADD('2020-01-01', 1)",
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 3b837df..1913f53 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -1,20 +1,18 @@
import unittest
-from sqlglot import (
- Dialect,
- Dialects,
- ErrorLevel,
- UnsupportedError,
- parse_one,
- transpile,
-)
+from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one
class Validator(unittest.TestCase):
dialect = None
- def validate_identity(self, sql):
- self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql)
+ def parse_one(self, sql):
+ return parse_one(sql, read=self.dialect)
+
+ def validate_identity(self, sql, write_sql=None):
+ expression = self.parse_one(sql)
+ self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect))
+ return expression
def validate_all(self, sql, read=None, write=None, pretty=False):
"""
@@ -28,12 +26,14 @@ class Validator(unittest.TestCase):
read (dict): Mapping of dialect -> SQL
write (dict): Mapping of dialect -> SQL
"""
- expression = parse_one(sql, read=self.dialect)
+ expression = self.parse_one(sql)
for read_dialect, read_sql in (read or {}).items():
with self.subTest(f"{read_dialect} -> {sql}"):
self.assertEqual(
- parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE),
+ parse_one(read_sql, read_dialect).sql(
+ self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty
+ ),
sql,
)
@@ -83,10 +83,6 @@ class TestDialect(Validator):
)
self.validate_all(
"CAST(a AS BINARY(4))",
- read={
- "presto": "CAST(a AS VARBINARY(4))",
- "sqlite": "CAST(a AS VARBINARY(4))",
- },
write={
"bigquery": "CAST(a AS BINARY(4))",
"clickhouse": "CAST(a AS BINARY(4))",
@@ -104,6 +100,24 @@ class TestDialect(Validator):
},
)
self.validate_all(
+ "CAST(a AS VARBINARY(4))",
+ write={
+ "bigquery": "CAST(a AS VARBINARY(4))",
+ "clickhouse": "CAST(a AS VARBINARY(4))",
+ "duckdb": "CAST(a AS VARBINARY(4))",
+ "mysql": "CAST(a AS VARBINARY(4))",
+ "hive": "CAST(a AS BINARY(4))",
+ "oracle": "CAST(a AS BLOB(4))",
+ "postgres": "CAST(a AS BYTEA(4))",
+ "presto": "CAST(a AS VARBINARY(4))",
+ "redshift": "CAST(a AS VARBYTE(4))",
+ "snowflake": "CAST(a AS VARBINARY(4))",
+ "sqlite": "CAST(a AS BLOB(4))",
+ "spark": "CAST(a AS BINARY(4))",
+ "starrocks": "CAST(a AS VARBINARY(4))",
+ },
+ )
+ self.validate_all(
"CAST(MAP('a', '1') AS MAP(TEXT, TEXT))",
write={
"clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))",
@@ -472,45 +486,57 @@ class TestDialect(Validator):
},
)
self.validate_all(
- "DATE_TRUNC(x, 'day')",
+ "DATE_TRUNC('day', x)",
write={
"mysql": "DATE(x)",
- "starrocks": "DATE(x)",
},
)
self.validate_all(
- "DATE_TRUNC(x, 'week')",
+ "DATE_TRUNC('week', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
- "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
},
)
self.validate_all(
- "DATE_TRUNC(x, 'month')",
+ "DATE_TRUNC('month', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
- "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
},
)
self.validate_all(
- "DATE_TRUNC(x, 'quarter')",
+ "DATE_TRUNC('quarter', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
- "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
},
)
self.validate_all(
- "DATE_TRUNC(x, 'year')",
+ "DATE_TRUNC('year', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
- "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
},
)
self.validate_all(
- "DATE_TRUNC(x, 'millenium')",
+ "DATE_TRUNC('millenium', x)",
write={
"mysql": UnsupportedError,
- "starrocks": UnsupportedError,
+ },
+ )
+ self.validate_all(
+ "DATE_TRUNC('year', x)",
+ read={
+ "starrocks": "DATE_TRUNC('year', x)",
+ },
+ write={
+ "starrocks": "DATE_TRUNC('year', x)",
+ },
+ )
+ self.validate_all(
+ "DATE_TRUNC(x, year)",
+ read={
+ "bigquery": "DATE_TRUNC(x, year)",
+ },
+ write={
+ "bigquery": "DATE_TRUNC(x, year)",
},
)
self.validate_all(
@@ -564,6 +590,22 @@ class TestDialect(Validator):
"spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
},
)
+ self.validate_all(
+ "TIMESTAMP '2022-01-01'",
+ write={
+ "mysql": "CAST('2022-01-01' AS TIMESTAMP)",
+ "starrocks": "CAST('2022-01-01' AS DATETIME)",
+ "hive": "CAST('2022-01-01' AS TIMESTAMP)",
+ },
+ )
+ self.validate_all(
+ "TIMESTAMP('2022-01-01')",
+ write={
+ "mysql": "TIMESTAMP('2022-01-01')",
+ "starrocks": "TIMESTAMP('2022-01-01')",
+ "hive": "TIMESTAMP('2022-01-01')",
+ },
+ )
for unit in ("DAY", "MONTH", "YEAR"):
self.validate_all(
@@ -1002,7 +1044,10 @@ class TestDialect(Validator):
)
def test_limit(self):
- self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"})
+ self.validate_all(
+ "SELECT * FROM data LIMIT 10, 20",
+ write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"},
+ )
self.validate_all(
"SELECT x FROM y LIMIT 10",
write={
@@ -1132,3 +1177,56 @@ class TestDialect(Validator):
"sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
},
)
+
+ def test_nullsafe_eq(self):
+ self.validate_all(
+ "SELECT a IS NOT DISTINCT FROM b",
+ read={
+ "mysql": "SELECT a <=> b",
+ "postgres": "SELECT a IS NOT DISTINCT FROM b",
+ },
+ write={
+ "mysql": "SELECT a <=> b",
+ "postgres": "SELECT a IS NOT DISTINCT FROM b",
+ },
+ )
+
+ def test_nullsafe_neq(self):
+ self.validate_all(
+ "SELECT a IS DISTINCT FROM b",
+ read={
+ "postgres": "SELECT a IS DISTINCT FROM b",
+ },
+ write={
+ "mysql": "SELECT NOT a <=> b",
+ "postgres": "SELECT a IS DISTINCT FROM b",
+ },
+ )
+
+ def test_hash_comments(self):
+ self.validate_all(
+ "SELECT 1 /* arbitrary content,,, until end-of-line */",
+ read={
+ "mysql": "SELECT 1 # arbitrary content,,, until end-of-line",
+ "bigquery": "SELECT 1 # arbitrary content,,, until end-of-line",
+ "clickhouse": "SELECT 1 #! arbitrary content,,, until end-of-line",
+ },
+ )
+ self.validate_all(
+ """/* comment1 */
+SELECT
+ x, -- comment2
+ y -- comment3""",
+ read={
+ "mysql": """SELECT # comment1
+ x, # comment2
+ y # comment3""",
+ "bigquery": """SELECT # comment1
+ x, # comment2
+ y # comment3""",
+ "clickhouse": """SELECT # comment1
+ x, # comment2
+ y # comment3""",
+ },
+ pretty=True,
+ )
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index a25871c..1ba118b 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -1,3 +1,4 @@
+from sqlglot import expressions as exp
from tests.dialects.test_dialect import Validator
@@ -20,6 +21,52 @@ class TestMySQL(Validator):
self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')")
self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
+ self.validate_identity("@@GLOBAL.max_connections")
+
+ # SET Commands
+ self.validate_identity("SET @var_name = expr")
+ self.validate_identity("SET @name = 43")
+ self.validate_identity("SET @total_tax = (SELECT SUM(tax) FROM taxable_transactions)")
+ self.validate_identity("SET GLOBAL max_connections = 1000")
+ self.validate_identity("SET @@GLOBAL.max_connections = 1000")
+ self.validate_identity("SET SESSION sql_mode = 'TRADITIONAL'")
+ self.validate_identity("SET LOCAL sql_mode = 'TRADITIONAL'")
+ self.validate_identity("SET @@SESSION.sql_mode = 'TRADITIONAL'")
+ self.validate_identity("SET @@LOCAL.sql_mode = 'TRADITIONAL'")
+ self.validate_identity("SET @@sql_mode = 'TRADITIONAL'")
+ self.validate_identity("SET sql_mode = 'TRADITIONAL'")
+ self.validate_identity("SET PERSIST max_connections = 1000")
+ self.validate_identity("SET @@PERSIST.max_connections = 1000")
+ self.validate_identity("SET PERSIST_ONLY back_log = 100")
+ self.validate_identity("SET @@PERSIST_ONLY.back_log = 100")
+ self.validate_identity("SET @@SESSION.max_join_size = DEFAULT")
+ self.validate_identity("SET @@SESSION.max_join_size = @@GLOBAL.max_join_size")
+ self.validate_identity("SET @x = 1, SESSION sql_mode = ''")
+ self.validate_identity(
+ "SET GLOBAL sort_buffer_size = 1000000, SESSION sort_buffer_size = 1000000"
+ )
+ self.validate_identity(
+ "SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000"
+ )
+ self.validate_identity("SET GLOBAL max_connections = 1000, sort_buffer_size = 1000000")
+ self.validate_identity("SET @@GLOBAL.sort_buffer_size = 50000, sort_buffer_size = 1000000")
+ self.validate_identity("SET CHARACTER SET 'utf8'")
+ self.validate_identity("SET CHARACTER SET utf8")
+ self.validate_identity("SET CHARACTER SET DEFAULT")
+ self.validate_identity("SET NAMES 'utf8'")
+ self.validate_identity("SET NAMES DEFAULT")
+ self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'")
+ self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci")
+ self.validate_identity("SET autocommit = ON")
+
+ def test_escape(self):
+ self.validate_all(
+ r"'a \' b '' '",
+ write={
+ "mysql": r"'a '' b '' '",
+ "spark": r"'a \' b \' '",
+ },
+ )
def test_introducers(self):
self.validate_all(
@@ -115,14 +162,6 @@ class TestMySQL(Validator):
},
)
- def test_hash_comments(self):
- self.validate_all(
- "SELECT 1 # arbitrary content,,, until end-of-line",
- write={
- "mysql": "SELECT 1",
- },
- )
-
def test_mysql(self):
self.validate_all(
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
@@ -174,3 +213,242 @@ COMMENT='客户账户表'"""
},
pretty=True,
)
+
+ def test_show_simple(self):
+ for key, write_key in [
+ ("BINARY LOGS", "BINARY LOGS"),
+ ("MASTER LOGS", "BINARY LOGS"),
+ ("STORAGE ENGINES", "ENGINES"),
+ ("ENGINES", "ENGINES"),
+ ("EVENTS", "EVENTS"),
+ ("MASTER STATUS", "MASTER STATUS"),
+ ("PLUGINS", "PLUGINS"),
+ ("PRIVILEGES", "PRIVILEGES"),
+ ("PROFILES", "PROFILES"),
+ ("REPLICAS", "REPLICAS"),
+ ("SLAVE HOSTS", "REPLICAS"),
+ ]:
+ show = self.validate_identity(f"SHOW {key}", f"SHOW {write_key}")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, write_key)
+
+ def test_show_events(self):
+ for key in ["BINLOG", "RELAYLOG"]:
+ show = self.validate_identity(f"SHOW {key} EVENTS")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, f"{key} EVENTS")
+
+ show = self.validate_identity(f"SHOW {key} EVENTS IN 'log' FROM 1 LIMIT 2, 3")
+ self.assertEqual(show.text("log"), "log")
+ self.assertEqual(show.text("position"), "1")
+ self.assertEqual(show.text("limit"), "3")
+ self.assertEqual(show.text("offset"), "2")
+
+ show = self.validate_identity(f"SHOW {key} EVENTS LIMIT 1")
+ self.assertEqual(show.text("limit"), "1")
+ self.assertIsNone(show.args.get("offset"))
+
+ def test_show_like_or_where(self):
+ for key, write_key in [
+ ("CHARSET", "CHARACTER SET"),
+ ("CHARACTER SET", "CHARACTER SET"),
+ ("COLLATION", "COLLATION"),
+ ("DATABASES", "DATABASES"),
+ ("FUNCTION STATUS", "FUNCTION STATUS"),
+ ("PROCEDURE STATUS", "PROCEDURE STATUS"),
+ ("GLOBAL STATUS", "GLOBAL STATUS"),
+ ("SESSION STATUS", "STATUS"),
+ ("STATUS", "STATUS"),
+ ("GLOBAL VARIABLES", "GLOBAL VARIABLES"),
+ ("SESSION VARIABLES", "VARIABLES"),
+ ("VARIABLES", "VARIABLES"),
+ ]:
+ expected_name = write_key.strip("GLOBAL").strip()
+ template = "SHOW {}"
+ show = self.validate_identity(template.format(key), template.format(write_key))
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, expected_name)
+
+ template = "SHOW {} LIKE '%foo%'"
+ show = self.validate_identity(template.format(key), template.format(write_key))
+ self.assertIsInstance(show, exp.Show)
+ self.assertIsInstance(show.args["like"], exp.Literal)
+ self.assertEqual(show.text("like"), "%foo%")
+
+ template = "SHOW {} WHERE Column_name LIKE '%foo%'"
+ show = self.validate_identity(template.format(key), template.format(write_key))
+ self.assertIsInstance(show, exp.Show)
+ self.assertIsInstance(show.args["where"], exp.Where)
+ self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'")
+
+ def test_show_columns(self):
+ show = self.validate_identity("SHOW COLUMNS FROM tbl_name")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, "COLUMNS")
+ self.assertEqual(show.text("target"), "tbl_name")
+ self.assertFalse(show.args["full"])
+
+ show = self.validate_identity("SHOW FULL COLUMNS FROM tbl_name FROM db_name LIKE '%foo%'")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.text("target"), "tbl_name")
+ self.assertTrue(show.args["full"])
+ self.assertEqual(show.text("db"), "db_name")
+ self.assertIsInstance(show.args["like"], exp.Literal)
+ self.assertEqual(show.text("like"), "%foo%")
+
+ def test_show_name(self):
+ for key in [
+ "CREATE DATABASE",
+ "CREATE EVENT",
+ "CREATE FUNCTION",
+ "CREATE PROCEDURE",
+ "CREATE TABLE",
+ "CREATE TRIGGER",
+ "CREATE VIEW",
+ "FUNCTION CODE",
+ "PROCEDURE CODE",
+ ]:
+ show = self.validate_identity(f"SHOW {key} foo")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, key)
+ self.assertEqual(show.text("target"), "foo")
+
+ def test_show_grants(self):
+ show = self.validate_identity(f"SHOW GRANTS FOR foo")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, "GRANTS")
+ self.assertEqual(show.text("target"), "foo")
+
+ def test_show_engine(self):
+ show = self.validate_identity("SHOW ENGINE foo STATUS")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, "ENGINE")
+ self.assertEqual(show.text("target"), "foo")
+ self.assertFalse(show.args["mutex"])
+
+ show = self.validate_identity("SHOW ENGINE foo MUTEX")
+ self.assertEqual(show.name, "ENGINE")
+ self.assertEqual(show.text("target"), "foo")
+ self.assertTrue(show.args["mutex"])
+
+ def test_show_errors(self):
+ for key in ["ERRORS", "WARNINGS"]:
+ show = self.validate_identity(f"SHOW {key}")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, key)
+
+ show = self.validate_identity(f"SHOW {key} LIMIT 2, 3")
+ self.assertEqual(show.text("limit"), "3")
+ self.assertEqual(show.text("offset"), "2")
+
+ def test_show_index(self):
+ show = self.validate_identity("SHOW INDEX FROM foo")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, "INDEX")
+ self.assertEqual(show.text("target"), "foo")
+
+ show = self.validate_identity("SHOW INDEX FROM foo FROM bar")
+ self.assertEqual(show.text("db"), "bar")
+
+ def test_show_db_like_or_where_sql(self):
+ for key in [
+ "OPEN TABLES",
+ "TABLE STATUS",
+ "TRIGGERS",
+ ]:
+ show = self.validate_identity(f"SHOW {key}")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, key)
+
+ show = self.validate_identity(f"SHOW {key} FROM db_name")
+ self.assertEqual(show.name, key)
+ self.assertEqual(show.text("db"), "db_name")
+
+ show = self.validate_identity(f"SHOW {key} LIKE '%foo%'")
+ self.assertEqual(show.name, key)
+ self.assertIsInstance(show.args["like"], exp.Literal)
+ self.assertEqual(show.text("like"), "%foo%")
+
+ show = self.validate_identity(f"SHOW {key} WHERE Column_name LIKE '%foo%'")
+ self.assertEqual(show.name, key)
+ self.assertIsInstance(show.args["where"], exp.Where)
+ self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'")
+
+ def test_show_processlist(self):
+ show = self.validate_identity("SHOW PROCESSLIST")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, "PROCESSLIST")
+ self.assertFalse(show.args["full"])
+
+ show = self.validate_identity("SHOW FULL PROCESSLIST")
+ self.assertEqual(show.name, "PROCESSLIST")
+ self.assertTrue(show.args["full"])
+
+ def test_show_profile(self):
+ show = self.validate_identity("SHOW PROFILE")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, "PROFILE")
+
+ show = self.validate_identity("SHOW PROFILE BLOCK IO")
+ self.assertEqual(show.args["types"][0].name, "BLOCK IO")
+
+ show = self.validate_identity(
+ "SHOW PROFILE BLOCK IO, PAGE FAULTS FOR QUERY 1 OFFSET 2 LIMIT 3"
+ )
+ self.assertEqual(show.args["types"][0].name, "BLOCK IO")
+ self.assertEqual(show.args["types"][1].name, "PAGE FAULTS")
+ self.assertEqual(show.text("query"), "1")
+ self.assertEqual(show.text("offset"), "2")
+ self.assertEqual(show.text("limit"), "3")
+
+ def test_show_replica_status(self):
+ show = self.validate_identity("SHOW REPLICA STATUS")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, "REPLICA STATUS")
+
+ show = self.validate_identity("SHOW SLAVE STATUS", "SHOW REPLICA STATUS")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, "REPLICA STATUS")
+
+ show = self.validate_identity("SHOW REPLICA STATUS FOR CHANNEL channel_name")
+ self.assertEqual(show.text("channel"), "channel_name")
+
+ def test_show_tables(self):
+ show = self.validate_identity("SHOW TABLES")
+ self.assertIsInstance(show, exp.Show)
+ self.assertEqual(show.name, "TABLES")
+
+ show = self.validate_identity("SHOW FULL TABLES FROM db_name LIKE '%foo%'")
+ self.assertTrue(show.args["full"])
+ self.assertEqual(show.text("db"), "db_name")
+ self.assertIsInstance(show.args["like"], exp.Literal)
+ self.assertEqual(show.text("like"), "%foo%")
+
+ def test_set_variable(self):
+ cmd = self.parse_one("SET SESSION x = 1")
+ item = cmd.expressions[0]
+ self.assertEqual(item.text("kind"), "SESSION")
+ self.assertIsInstance(item.this, exp.EQ)
+ self.assertEqual(item.this.left.name, "x")
+ self.assertEqual(item.this.right.name, "1")
+
+ cmd = self.parse_one("SET @@GLOBAL.x = @@GLOBAL.y")
+ item = cmd.expressions[0]
+ self.assertEqual(item.text("kind"), "")
+ self.assertIsInstance(item.this, exp.EQ)
+ self.assertIsInstance(item.this.left, exp.SessionParameter)
+ self.assertIsInstance(item.this.right, exp.SessionParameter)
+
+ cmd = self.parse_one("SET NAMES 'charset_name' COLLATE 'collation_name'")
+ item = cmd.expressions[0]
+ self.assertEqual(item.text("kind"), "NAMES")
+ self.assertEqual(item.name, "charset_name")
+ self.assertEqual(item.text("collate"), "collation_name")
+
+ cmd = self.parse_one("SET CHARSET DEFAULT")
+ item = cmd.expressions[0]
+ self.assertEqual(item.text("kind"), "CHARACTER SET")
+ self.assertEqual(item.this.name, "DEFAULT")
+
+ cmd = self.parse_one("SET x = 1, y = 2")
+ self.assertEqual(len(cmd.expressions), 2)
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 35141e2..8294eea 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -8,7 +8,9 @@ class TestPostgres(Validator):
def test_ddl(self):
self.validate_all(
"CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
- write={"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"},
+ write={
+ "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"
+ },
)
self.validate_all(
"CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)",
@@ -59,15 +61,27 @@ class TestPostgres(Validator):
def test_postgres(self):
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
- self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END")
- self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END")
- self.validate_identity('SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')')
+ self.validate_identity(
+ "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
+ )
+ self.validate_identity(
+ "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END"
+ )
+ self.validate_identity(
+ 'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')'
+ )
self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')")
- self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')")
- self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))")
+ self.validate_identity(
+ "SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')"
+ )
+ self.validate_identity(
+ "SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))"
+ )
self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
- self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')")
+ self.validate_identity(
+ "SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')"
+ )
self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
self.validate_identity("SELECT e'\\xDEADBEEF'")
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
@@ -75,7 +89,7 @@ class TestPostgres(Validator):
self.validate_all(
"CREATE TABLE x (a UUID, b BYTEA)",
write={
- "duckdb": "CREATE TABLE x (a UUID, b BINARY)",
+ "duckdb": "CREATE TABLE x (a UUID, b VARBINARY)",
"presto": "CREATE TABLE x (a UUID, b VARBINARY)",
"hive": "CREATE TABLE x (a UUID, b BINARY)",
"spark": "CREATE TABLE x (a UUID, b BINARY)",
@@ -153,7 +167,9 @@ class TestPostgres(Validator):
)
self.validate_all(
"SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss",
- read={"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"},
+ read={
+ "postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"
+ },
)
self.validate_all(
"SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
@@ -169,11 +185,15 @@ class TestPostgres(Validator):
)
self.validate_all(
"SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL",
- read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"},
+ read={
+ "postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"
+ },
)
self.validate_all(
"SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL",
- read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"},
+ read={
+ "postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"
+ },
)
self.validate_all(
"'[1,2,3]'::json->2",
@@ -184,7 +204,8 @@ class TestPostgres(Validator):
write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->'b'"""},
)
self.validate_all(
- """'{"x": {"y": 1}}'::json->'x'->'y'""", write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""}
+ """'{"x": {"y": 1}}'::json->'x'->'y'""",
+ write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""},
)
self.validate_all(
"""'{"x": {"y": 1}}'::json->'x'::json->'y'""",
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index 1ed2bb6..5309a34 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -61,4 +61,6 @@ class TestRedshift(Validator):
"SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'"
)
self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)")
- self.validate_identity("SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'")
+ self.validate_identity(
+ "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
+ )
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index fea2311..1846b17 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -336,7 +336,8 @@ class TestSnowflake(Validator):
def test_table_literal(self):
# All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html
self.validate_all(
- r"""SELECT * FROM TABLE('MYTABLE')""", write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""}
+ r"""SELECT * FROM TABLE('MYTABLE')""",
+ write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""},
)
self.validate_all(
@@ -352,15 +353,123 @@ class TestSnowflake(Validator):
write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""},
)
- self.validate_all(r"""SELECT * FROM TABLE($MYVAR)""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""})
+ self.validate_all(
+ r"""SELECT * FROM TABLE($MYVAR)""",
+ write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""},
+ )
- self.validate_all(r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""})
+ self.validate_all(
+ r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""}
+ )
self.validate_all(
- r"""SELECT * FROM TABLE(:BINDING)""", write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""}
+ r"""SELECT * FROM TABLE(:BINDING)""",
+ write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""},
)
self.validate_all(
r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""",
write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""},
)
+
+ def test_flatten(self):
+ self.validate_all(
+ """
+ select
+ dag_report.acct_id,
+ dag_report.report_date,
+ dag_report.report_uuid,
+ dag_report.airflow_name,
+ dag_report.dag_id,
+ f.value::varchar as operator
+ from cs.telescope.dag_report,
+ table(flatten(input=>split(operators, ','))) f
+ """,
+ write={
+ "snowflake": """SELECT
+ dag_report.acct_id,
+ dag_report.report_date,
+ dag_report.report_uuid,
+ dag_report.airflow_name,
+ dag_report.dag_id,
+ CAST(f.value AS VARCHAR) AS operator
+FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS f"""
+ },
+ pretty=True,
+ )
+
+ # All examples from https://docs.snowflake.com/en/sql-reference/functions/flatten.html#syntax
+ self.validate_all(
+ "SELECT * FROM TABLE(FLATTEN(input => parse_json('[1, ,77]'))) f",
+ write={
+ "snowflake": "SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[1, ,77]'))) AS f"
+ },
+ )
+
+ self.validate_all(
+ """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), outer => true)) f""",
+ write={
+ "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), outer => TRUE)) AS f"""
+ },
+ )
+
+ self.validate_all(
+ """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), path => 'b')) f""",
+ write={
+ "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), path => 'b')) AS f"""
+ },
+ )
+
+ self.validate_all(
+ """SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'))) f""",
+ write={"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'))) AS f"""},
+ )
+
+ self.validate_all(
+ """SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'), outer => true)) f""",
+ write={
+ "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'), outer => TRUE)) AS f"""
+ },
+ )
+
+ self.validate_all(
+ """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) f""",
+ write={
+ "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) AS f"""
+ },
+ )
+
+ self.validate_all(
+ """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true)) f""",
+ write={
+ "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE)) AS f"""
+ },
+ )
+
+ self.validate_all(
+ """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true, mode => 'object')) f""",
+ write={
+ "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE, mode => 'object')) AS f"""
+ },
+ )
+
+ self.validate_all(
+ """
+ SELECT id as "ID",
+ f.value AS "Contact",
+ f1.value:type AS "Type",
+ f1.value:content AS "Details"
+ FROM persons p,
+ lateral flatten(input => p.c, path => 'contact') f,
+ lateral flatten(input => f.value:business) f1
+ """,
+ write={
+ "snowflake": """SELECT
+ id AS "ID",
+ f.value AS "Contact",
+ f1.value['type'] AS "Type",
+ f1.value['content'] AS "Details"
+FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL FLATTEN(input => f.value['business']) f1""",
+ },
+ pretty=True,
+ )
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index 8605bd1..4470722 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -284,4 +284,6 @@ TBLPROPERTIES (
)
def test_iif(self):
- self.validate_all("SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"})
+ self.validate_all(
+ "SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"}
+ )
diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py
index 1fe1a57..35d8b45 100644
--- a/tests/dialects/test_starrocks.py
+++ b/tests/dialects/test_starrocks.py
@@ -6,3 +6,6 @@ class TestMySQL(Validator):
def test_identity(self):
self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
+
+ def test_time(self):
+ self.validate_identity("TIMESTAMP('2022-01-01')")
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index d22a9c2..a60f48d 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -278,12 +278,19 @@ class TestTSQL(Validator):
def test_add_date(self):
self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
self.validate_all(
- "SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"}
+ "SELECT DATEADD(year, 1, '2017/08/25')",
+ write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"},
+ )
+ self.validate_all(
+ "SELECT DATEADD(qq, 1, '2017/08/25')",
+ write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"},
)
- self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"})
self.validate_all(
"SELECT DATEADD(wk, 1, '2017/08/25')",
- write={"spark": "SELECT DATE_ADD('2017/08/25', 7)", "databricks": "SELECT DATEADD(week, 1, '2017/08/25')"},
+ write={
+ "spark": "SELECT DATE_ADD('2017/08/25', 7)",
+ "databricks": "SELECT DATEADD(week, 1, '2017/08/25')",
+ },
)
def test_date_diff(self):
@@ -370,13 +377,21 @@ class TestTSQL(Validator):
"SELECT FORMAT(1000000.01,'###,###.###')",
write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"},
)
- self.validate_all("SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"})
+ self.validate_all(
+ "SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}
+ )
self.validate_all(
"SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')",
write={"spark": "SELECT DATE_FORMAT('01-01-1991', 'dd.mm.yyyy')"},
)
self.validate_all(
- "SELECT FORMAT(date_col, 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"}
+ "SELECT FORMAT(date_col, 'dd.mm.yyyy')",
+ write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"},
+ )
+ self.validate_all(
+ "SELECT FORMAT(date_col, 'm')",
+ write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"},
+ )
+ self.validate_all(
+ "SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}
)
- self.validate_all("SELECT FORMAT(date_col, 'm')", write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"})
- self.validate_all("SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"})
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index d7084ac..836ab28 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -523,6 +523,8 @@ DROP VIEW a.b
DROP VIEW IF EXISTS a
DROP VIEW IF EXISTS a.b
SHOW TABLES
+USE db
+ROLLBACK
EXPLAIN SELECT * FROM x
INSERT INTO x SELECT * FROM y
INSERT INTO x (SELECT * FROM y)
@@ -569,3 +571,13 @@ SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)
SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)
SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)
SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)
+SELECT CAST(x AS INT) /* comment */ FROM foo
+SELECT a /* x */, b /* x */
+SELECT * FROM foo /* x */, bla /* x */
+SELECT 1 /* comment */ + 1
+SELECT 1 /* c1 */ + 2 /* c2 */
+SELECT 1 /* c1 */ + 2 /* c2 */ + 3 /* c3 */
+SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
+SELECT x FROM a.b.c /* x */, e.f.g /* x */
+SELECT FOO(x /* c */) /* FOO */, b /* b */
+SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql
index a958c08..1176078 100644
--- a/tests/fixtures/optimizer/qualify_columns.sql
+++ b/tests/fixtures/optimizer/qualify_columns.sql
@@ -104,6 +104,16 @@ SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_
SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x;
+# dialect: starrocks
+# execute: false
+SELECT DATE_TRUNC('week', a) AS a FROM x;
+SELECT DATE_TRUNC('week', x.a) AS a FROM x AS x;
+
+# dialect: bigquery
+# execute: false
+SELECT DATE_TRUNC(a, MONTH) AS a FROM x;
+SELECT DATE_TRUNC(x.a, MONTH) AS a FROM x AS x;
+
--------------------------------------
-- Derived tables
--------------------------------------
diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql
index 07e818f..7207ba2 100644
--- a/tests/fixtures/optimizer/simplify.sql
+++ b/tests/fixtures/optimizer/simplify.sql
@@ -79,6 +79,15 @@ NULL;
NULL = NULL;
NULL;
+NULL <=> NULL;
+TRUE;
+
+a IS NOT DISTINCT FROM a;
+TRUE;
+
+NULL IS DISTINCT FROM NULL;
+FALSE;
+
NOT (NOT TRUE);
TRUE;
diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql
index 2570650..5e27b5e 100644
--- a/tests/fixtures/pretty.sql
+++ b/tests/fixtures/pretty.sql
@@ -287,3 +287,31 @@ SELECT
"fffffff"
)
);
+/*
+ multi
+ line
+ comment
+*/
+SELECT * FROM foo;
+/*
+ multi
+ line
+ comment
+*/
+SELECT
+ *
+FROM foo;
+SELECT x FROM a.b.c /*x*/, e.f.g /*x*/;
+SELECT
+ x
+FROM a.b.c /* x */, e.f.g /* x */;
+SELECT x FROM (SELECT * FROM bla /*x*/WHERE id = 1) /*x*/;
+SELECT
+ x
+FROM (
+ SELECT
+ *
+ FROM bla /* x */
+ WHERE
+ id = 1
+) /* x */;
diff --git a/tests/test_build.py b/tests/test_build.py
index b7b6865..721c868 100644
--- a/tests/test_build.py
+++ b/tests/test_build.py
@@ -100,15 +100,21 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
),
(
- lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer"),
+ lambda: select("x")
+ .from_("tbl")
+ .join(exp.Table(this="tbl2"), join_type="left outer"),
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
),
(
- lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
+ lambda: select("x")
+ .from_("tbl")
+ .join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
"SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo",
),
(
- lambda: select("x").from_("tbl").join(select("y").from_("tbl2"), join_type="left outer"),
+ lambda: select("x")
+ .from_("tbl")
+ .join(select("y").from_("tbl2"), join_type="left outer"),
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)",
),
(
@@ -131,7 +137,9 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased",
),
(
- lambda: select("x").from_("tbl").join(parse_one("left join x", into=exp.Join), on="a=b"),
+ lambda: select("x")
+ .from_("tbl")
+ .join(parse_one("left join x", into=exp.Join), on="a=b"),
"SELECT x FROM tbl LEFT JOIN x ON a = b",
),
(
@@ -139,7 +147,9 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT JOIN x ON a = b",
),
(
- lambda: select("x").from_("tbl").join("select b from tbl2", on="a=b", join_type="left"),
+ lambda: select("x")
+ .from_("tbl")
+ .join("select b from tbl2", on="a=b", join_type="left"),
"SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b",
),
(
@@ -162,7 +172,10 @@ class TestBuild(unittest.TestCase):
(
lambda: select("x", "y", "z")
.from_("merged_df")
- .join("vte_diagnosis_df", using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")]),
+ .join(
+ "vte_diagnosis_df",
+ using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")],
+ ),
"SELECT x, y, z FROM merged_df JOIN vte_diagnosis_df USING (patient_id, encounter_id)",
),
(
@@ -222,7 +235,10 @@ class TestBuild(unittest.TestCase):
"SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a",
),
(
- lambda: select("x", "y", "z", "a").from_("tbl").cluster_by("x, y", "z").cluster_by("a"),
+ lambda: select("x", "y", "z", "a")
+ .from_("tbl")
+ .cluster_by("x, y", "z")
+ .cluster_by("a"),
"SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a",
),
(
@@ -239,7 +255,9 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
- lambda: select("x").from_("tbl").with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
+ lambda: select("x")
+ .from_("tbl")
+ .with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
"WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
@@ -247,7 +265,9 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
- lambda: select("x").from_("tbl").with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
+ lambda: select("x")
+ .from_("tbl")
+ .with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
"WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl",
),
(
@@ -258,7 +278,10 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl",
),
(
- lambda: select("x").from_("tbl").with_("tbl", as_=select("x", "y").from_("tbl2")).select("y"),
+ lambda: select("x")
+ .from_("tbl")
+ .with_("tbl", as_=select("x", "y").from_("tbl2"))
+ .select("y"),
"WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl",
),
(
@@ -266,35 +289,59 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
- lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").group_by("x"),
+ lambda: select("x")
+ .with_("tbl", as_=select("x").from_("tbl2"))
+ .from_("tbl")
+ .group_by("x"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x",
),
(
- lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").order_by("x"),
+ lambda: select("x")
+ .with_("tbl", as_=select("x").from_("tbl2"))
+ .from_("tbl")
+ .order_by("x"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x",
),
(
- lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").limit(10),
+ lambda: select("x")
+ .with_("tbl", as_=select("x").from_("tbl2"))
+ .from_("tbl")
+ .limit(10),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10",
),
(
- lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").offset(10),
+ lambda: select("x")
+ .with_("tbl", as_=select("x").from_("tbl2"))
+ .from_("tbl")
+ .offset(10),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10",
),
(
- lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").join("tbl3"),
+ lambda: select("x")
+ .with_("tbl", as_=select("x").from_("tbl2"))
+ .from_("tbl")
+ .join("tbl3"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3",
),
(
- lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").distinct(),
+ lambda: select("x")
+ .with_("tbl", as_=select("x").from_("tbl2"))
+ .from_("tbl")
+ .distinct(),
"WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl",
),
(
- lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").where("x > 10"),
+ lambda: select("x")
+ .with_("tbl", as_=select("x").from_("tbl2"))
+ .from_("tbl")
+ .where("x > 10"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10",
),
(
- lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").having("x > 20"),
+ lambda: select("x")
+ .with_("tbl", as_=select("x").from_("tbl2"))
+ .from_("tbl")
+ .having("x > 20"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20",
),
(lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"),
@@ -354,7 +401,9 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM (SELECT x FROM tbl) AS foo WHERE x > 0",
),
(
- lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"),
+ lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select(
+ "x"
+ ),
"SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned",
),
(
diff --git a/tests/test_executor.py b/tests/test_executor.py
index ef1a706..49805b9 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -33,7 +33,10 @@ class TestExecutor(unittest.TestCase):
)
cls.cache = {}
- cls.sqls = [(sql, expected) for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")]
+ cls.sqls = [
+ (sql, expected)
+ for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")
+ ]
@classmethod
def tearDownClass(cls):
@@ -63,7 +66,9 @@ class TestExecutor(unittest.TestCase):
def test_execute_tpch(self):
def to_csv(expression):
if isinstance(expression, exp.Table):
- return parse_one(f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}")
+ return parse_one(
+ f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
+ )
return expression
for sql, _ in self.sqls[0:3]:
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index adfd329..63371d8 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -30,7 +30,9 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)"))
self.assertEqual(exp.Table(pivots=[]), exp.Table())
self.assertNotEqual(exp.Table(pivots=[None]), exp.Table())
- self.assertEqual(exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False))
+ self.assertEqual(
+ exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False)
+ )
def test_find(self):
expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y")
@@ -89,7 +91,9 @@ class TestExpressions(unittest.TestCase):
self.assertIsNone(column.find_ancestor(exp.Join))
def test_alias_or_name(self):
- expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
+ expression = parse_one(
+ "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
+ )
self.assertEqual(
[e.alias_or_name for e in expression.expressions],
["a", "B", "e", "*", "zz", "z"],
@@ -166,7 +170,9 @@ class TestExpressions(unittest.TestCase):
"SELECT * FROM foo WHERE ? > 100",
)
self.assertEqual(
- exp.replace_placeholders(parse_one("select * from :name WHERE ? > 100"), another_name="bla").sql(),
+ exp.replace_placeholders(
+ parse_one("select * from :name WHERE ? > 100"), another_name="bla"
+ ).sql(),
"SELECT * FROM :name WHERE ? > 100",
)
self.assertEqual(
@@ -183,7 +189,9 @@ class TestExpressions(unittest.TestCase):
)
def test_named_selects(self):
- expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
+ expression = parse_one(
+ "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
+ )
self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"])
expression = parse_one(
@@ -367,7 +375,9 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(len(list(expression.walk())), 9)
self.assertEqual(len(list(expression.walk(bfs=False))), 9)
self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk()))
- self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)))
+ self.assertTrue(
+ all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))
+ )
def test_functions(self):
self.assertIsInstance(parse_one("ABS(a)"), exp.Abs)
@@ -512,14 +522,21 @@ class TestExpressions(unittest.TestCase):
),
exp.Properties(
expressions=[
- exp.FileFormatProperty(this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")),
+ exp.FileFormatProperty(
+ this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")
+ ),
exp.PartitionedByProperty(
this=exp.Literal.string("PARTITIONED_BY"),
- value=exp.Tuple(expressions=[exp.to_identifier("a"), exp.to_identifier("b")]),
+ value=exp.Tuple(
+ expressions=[exp.to_identifier("a"), exp.to_identifier("b")]
+ ),
+ ),
+ exp.AnonymousProperty(
+ this=exp.Literal.string("custom"), value=exp.Literal.number(1)
),
- exp.AnonymousProperty(this=exp.Literal.string("custom"), value=exp.Literal.number(1)),
exp.TableFormatProperty(
- this=exp.Literal.string("TABLE_FORMAT"), value=exp.to_identifier("test_format")
+ this=exp.Literal.string("TABLE_FORMAT"),
+ value=exp.to_identifier("test_format"),
),
exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL),
exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE),
@@ -538,7 +555,10 @@ class TestExpressions(unittest.TestCase):
((1, "2", None), "(1, '2', NULL)"),
([1, "2", None], "ARRAY(1, '2', NULL)"),
({"x": None}, "MAP('x', NULL)"),
- (datetime.datetime(2022, 10, 1, 1, 1, 1), "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')"),
+ (
+ datetime.datetime(2022, 10, 1, 1, 1, 1),
+ "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')",
+ ),
(
datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc),
"TIME_STR_TO_TIME('2022-10-01 01:01:01.000000+0000')",
@@ -548,30 +568,48 @@ class TestExpressions(unittest.TestCase):
with self.subTest(value):
self.assertEqual(exp.convert(value).sql(), expected)
- def test_annotation_alias(self):
- sql = "SELECT a, b AS B, c # comment, d AS D # another_comment FROM foo"
+ def test_comment_alias(self):
+ sql = """
+ SELECT
+ a,
+ b AS B,
+ c, /*comment*/
+ d AS D, -- another comment
+ CAST(x AS INT) -- final comment
+ FROM foo
+ """
expression = parse_one(sql)
self.assertEqual(
[e.alias_or_name for e in expression.expressions],
- ["a", "B", "c", "D"],
+ ["a", "B", "c", "D", "x"],
+ )
+ self.assertEqual(
+ expression.sql(),
+ "SELECT a, b AS B, c /* comment */, d AS D /* another comment */, CAST(x AS INT) /* final comment */ FROM foo",
+ )
+ self.assertEqual(
+ expression.sql(comments=False),
+ "SELECT a, b AS B, c, d AS D, CAST(x AS INT) FROM foo",
)
- self.assertEqual(expression.sql(), "SELECT a, b AS B, c, d AS D")
- self.assertEqual(expression.expressions[2].name, "comment")
self.assertEqual(
- expression.sql(pretty=True, annotations=False),
+ expression.sql(pretty=True, comments=False),
"""SELECT
a,
b AS B,
c,
- d AS D""",
+ d AS D,
+ CAST(x AS INT)
+FROM foo""",
)
self.assertEqual(
expression.sql(pretty=True),
"""SELECT
a,
b AS B,
- c # comment,
- d AS D # another_comment FROM foo""",
+ c, -- comment
+ d AS D, -- another comment
+ CAST(x AS INT) -- final comment
+FROM foo""",
)
def test_to_table(self):
@@ -605,5 +643,9 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(expression, exp.Union)
self.assertEqual(expression.named_selects, ["cola", "colb"])
self.assertEqual(
- expression.selects, [exp.Column(this=exp.to_identifier("cola")), exp.Column(this=exp.to_identifier("colb"))]
+ expression.selects,
+ [
+ exp.Column(this=exp.to_identifier("cola")),
+ exp.Column(this=exp.to_identifier("colb")),
+ ],
)
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 3b5990f..a1b7e70 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -67,7 +67,9 @@ class TestOptimizer(unittest.TestCase):
}
def check_file(self, file, func, pretty=False, execute=False, **kwargs):
- for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1):
+ for i, (meta, sql, expected) in enumerate(
+ load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1
+ ):
title = meta.get("title") or f"{i}, {sql}"
dialect = meta.get("dialect")
leave_tables_isolated = meta.get("leave_tables_isolated")
@@ -90,7 +92,9 @@ class TestOptimizer(unittest.TestCase):
if string_to_bool(should_execute):
with self.subTest(f"(execute) {title}"):
- df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df()
+ df1 = self.conn.execute(
+ sqlglot.transpile(sql, read=dialect, write="duckdb")[0]
+ ).df()
df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df()
assert_frame_equal(df1, df2)
@@ -268,7 +272,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)")
self.assertEqual(
- scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)"
+ scopes[3].expression.sql(),
+ "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
)
self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y")
self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
@@ -287,7 +292,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
# Check that we can walk in scope from an arbitrary node
self.assertEqual(
- {node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)},
+ {
+ node.sql()
+ for node, *_ in walk_in_scope(expression.find(exp.Where))
+ if isinstance(node, exp.Column)
+ },
{"s.b"},
)
@@ -324,7 +333,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)
def test_cache_annotation(self):
- expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1"))
+ expression = annotate_types(
+ parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
+ )
self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)
def test_binary_annotation(self):
@@ -384,7 +395,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
"""
expression = annotate_types(parse_one(sql), schema=schema)
- self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT) # tbl.cola + tbl.colb + 'foo' AS col
+ self.assertEqual(
+ expression.expressions[0].type, exp.DataType.Type.TEXT
+ ) # tbl.cola + tbl.colb + 'foo' AS col
outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo'
self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)
@@ -396,7 +409,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)
cte_select = expression.args["with"].expressions[0].this
- self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola + 'bla' AS cola
+ self.assertEqual(
+ cte_select.expressions[0].type, exp.DataType.Type.VARCHAR
+ ) # x.cola + 'bla' AS cola
self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb
cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla'
@@ -405,7 +420,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)
# Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
- for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]):
+ for d, t in zip(
+ cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
+ ):
self.assertEqual(d.this.expressions[0].this.type, t)
def test_function_annotation(self):
@@ -421,6 +438,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb)
self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb
+ sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x"
+
+ case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
+ self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR)
+
+ case_expr = case_expr_alias.this
+ self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR)
+
+ case_ifs_expr = case_expr.args["ifs"][0]
+ self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR)
+
def test_unknown_annotation(self):
schema = {"x": {"cola": "VARCHAR"}}
sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"
@@ -431,8 +461,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
concat_expr = concat_expr_alias.this
self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
- self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) # SOME_ANONYMOUS_FUNC(x.cola)
- self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola (arg)
+ self.assertEqual(
+ concat_expr.right.type, exp.DataType.Type.UNKNOWN
+ ) # SOME_ANONYMOUS_FUNC(x.cola)
+ self.assertEqual(
+ concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR
+ ) # x.cola (arg)
def test_null_annotation(self):
expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 9afeae6..04c20b1 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -23,8 +23,6 @@ class TestParser(unittest.TestCase):
def test_float(self):
self.assertEqual(parse_one(".2"), parse_one("0.2"))
- self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
- self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))
def test_table(self):
tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
@@ -33,7 +31,9 @@ class TestParser(unittest.TestCase):
def test_select(self):
self.assertIsNotNone(parse_one("select 1 natural"))
self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"])
- self.assertIsNotNone(parse_one("select * from x where a = (select 1) order by x.y").args["order"])
+ self.assertIsNotNone(
+ parse_one("select * from x where a = (select 1) order by x.y").args["order"]
+ )
self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1)
self.assertEqual(
parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
@@ -125,26 +125,70 @@ class TestParser(unittest.TestCase):
def test_var(self):
self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'")
- def test_annotations(self):
+ def test_comments(self):
expression = parse_one(
"""
- SELECT
- a #annotation1,
- b as B #annotation2:testing ,
- "test#annotation",c#annotation3, d #annotation4,
- e #,
- f # space
+ --comment1
+ SELECT /* this won't be used */
+ a, --comment2
+ b as B, --comment3:testing
+ "test--annotation",
+ c, --comment4 --foo
+ e, --
+ f -- space
FROM foo
"""
)
- assert expression.expressions[0].name == "annotation1"
- assert expression.expressions[1].name == "annotation2:testing"
- assert expression.expressions[2].name == "test#annotation"
- assert expression.expressions[3].name == "annotation3"
- assert expression.expressions[4].name == "annotation4"
- assert expression.expressions[5].name == ""
- assert expression.expressions[6].name == "space"
+ self.assertEqual(expression.comment, "comment1")
+ self.assertEqual(expression.expressions[0].comment, "comment2")
+ self.assertEqual(expression.expressions[1].comment, "comment3:testing")
+ self.assertEqual(expression.expressions[2].comment, None)
+ self.assertEqual(expression.expressions[3].comment, "comment4 --foo")
+ self.assertEqual(expression.expressions[4].comment, "")
+ self.assertEqual(expression.expressions[5].comment, " space")
+
+ def test_type_literals(self):
+ self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
+ self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))
+ self.assertEqual(
+ parse_one("TIMESTAMP '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP)"
+ )
+ self.assertEqual(
+ parse_one("TIMESTAMP(1) '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP(1))"
+ )
+ self.assertEqual(
+ parse_one("TIMESTAMP WITH TIME ZONE '2022-01-01'").sql(),
+ "CAST('2022-01-01' AS TIMESTAMPTZ)",
+ )
+ self.assertEqual(
+ parse_one("TIMESTAMP WITH LOCAL TIME ZONE '2022-01-01'").sql(),
+ "CAST('2022-01-01' AS TIMESTAMPLTZ)",
+ )
+ self.assertEqual(
+ parse_one("TIMESTAMP WITHOUT TIME ZONE '2022-01-01'").sql(),
+ "CAST('2022-01-01' AS TIMESTAMP)",
+ )
+ self.assertEqual(
+ parse_one("TIMESTAMP(1) WITH TIME ZONE '2022-01-01'").sql(),
+ "CAST('2022-01-01' AS TIMESTAMPTZ(1))",
+ )
+ self.assertEqual(
+ parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE '2022-01-01'").sql(),
+ "CAST('2022-01-01' AS TIMESTAMPLTZ(1))",
+ )
+ self.assertEqual(
+ parse_one("TIMESTAMP(1) WITHOUT TIME ZONE '2022-01-01'").sql(),
+ "CAST('2022-01-01' AS TIMESTAMP(1))",
+ )
+ self.assertEqual(parse_one("TIMESTAMP(1) WITH TIME ZONE").sql(), "TIMESTAMPTZ(1)")
+ self.assertEqual(parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE").sql(), "TIMESTAMPLTZ(1)")
+ self.assertEqual(parse_one("TIMESTAMP(1) WITHOUT TIME ZONE").sql(), "TIMESTAMP(1)")
+ self.assertEqual(parse_one("""JSON '{"x":"y"}'""").sql(), """CAST('{"x":"y"}' AS JSON)""")
+ self.assertIsInstance(parse_one("TIMESTAMP(1)"), exp.Func)
+ self.assertIsInstance(parse_one("TIMESTAMP('2022-01-01')"), exp.Func)
+ self.assertIsInstance(parse_one("TIMESTAMP()"), exp.Func)
+ self.assertIsInstance(parse_one("map.x"), exp.Column)
def test_pretty_config_override(self):
self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")
diff --git a/tests/test_schema.py b/tests/test_schema.py
index bab97d8..cc0e3d1 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -1,281 +1,141 @@
import unittest
-from sqlglot import table
-from sqlglot.dataframe.sql import types as df_types
+from sqlglot import exp, to_table
+from sqlglot.errors import SchemaError
from sqlglot.schema import MappingSchema, ensure_schema
class TestSchema(unittest.TestCase):
- def test_schema(self):
- schema = ensure_schema(
- {
- "x": {
- "a": "uint64",
- }
- }
- )
- self.assertEqual(
- schema.column_names(
- table(
- "x",
- )
- ),
- ["a"],
- )
- with self.assertRaises(ValueError):
- schema.column_names(table("x", db="db", catalog="c"))
- with self.assertRaises(ValueError):
- schema.column_names(table("x", db="db"))
- with self.assertRaises(ValueError):
- schema.column_names(table("x2"))
+ def assert_column_names(self, schema, *table_results):
+ for table, result in table_results:
+ with self.subTest(f"{table} -> {result}"):
+ self.assertEqual(schema.column_names(to_table(table)), result)
- with self.assertRaises(ValueError):
- schema.add_table(table("y", db="db"), {"b": "string"})
- with self.assertRaises(ValueError):
- schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})
+ def assert_column_names_raises(self, schema, *tables):
+ for table in tables:
+ with self.subTest(table):
+ with self.assertRaises(SchemaError):
+ schema.column_names(to_table(table))
- schema.add_table(table("y"), {"b": "string"})
- schema_with_y = {
- "x": {
- "a": "uint64",
- },
- "y": {
- "b": "string",
- },
- }
- self.assertEqual(schema.schema, schema_with_y)
-
- new_schema = schema.copy()
- new_schema.add_table(table("z"), {"c": "string"})
- self.assertEqual(schema.schema, schema_with_y)
- self.assertEqual(
- new_schema.schema,
- {
- "x": {
- "a": "uint64",
- },
- "y": {
- "b": "string",
- },
- "z": {
- "c": "string",
- },
- },
- )
- schema.add_table(table("m"), {"d": "string"})
- schema.add_table(table("n"), {"e": "string"})
- schema_with_m_n = {
- "x": {
- "a": "uint64",
- },
- "y": {
- "b": "string",
- },
- "m": {
- "d": "string",
- },
- "n": {
- "e": "string",
- },
- }
- self.assertEqual(schema.schema, schema_with_m_n)
- new_schema = schema.copy()
- new_schema.add_table(table("o"), {"f": "string"})
- new_schema.add_table(table("p"), {"g": "string"})
- self.assertEqual(schema.schema, schema_with_m_n)
- self.assertEqual(
- new_schema.schema,
+ def test_schema(self):
+ schema = ensure_schema(
{
"x": {
"a": "uint64",
},
"y": {
- "b": "string",
- },
- "m": {
- "d": "string",
- },
- "n": {
- "e": "string",
- },
- "o": {
- "f": "string",
- },
- "p": {
- "g": "string",
+ "b": "uint64",
+ "c": "uint64",
},
},
)
- schema = ensure_schema(
- {
- "db": {
- "x": {
- "a": "uint64",
- }
- }
- }
+ self.assert_column_names(
+ schema,
+ ("x", ["a"]),
+ ("y", ["b", "c"]),
+ ("z.x", ["a"]),
+ ("z.x.y", ["b", "c"]),
)
- self.assertEqual(schema.column_names(table("x", db="db")), ["a"])
- with self.assertRaises(ValueError):
- schema.column_names(table("x", db="db", catalog="c"))
- with self.assertRaises(ValueError):
- schema.column_names(table("x"))
- with self.assertRaises(ValueError):
- schema.column_names(table("x", db="db2"))
- with self.assertRaises(ValueError):
- schema.column_names(table("x2", db="db"))
- with self.assertRaises(ValueError):
- schema.add_table(table("y"), {"b": "string"})
- with self.assertRaises(ValueError):
- schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})
+ self.assert_column_names_raises(
+ schema,
+ "z",
+ "z.z",
+ "z.z.z",
+ )
- schema.add_table(table("y", db="db"), {"b": "string"})
- self.assertEqual(
- schema.schema,
+ def test_schema_db(self):
+ schema = ensure_schema(
{
- "db": {
+ "d1": {
"x": {
"a": "uint64",
},
"y": {
- "b": "string",
+ "b": "uint64",
+ },
+ },
+ "d2": {
+ "x": {
+ "c": "uint64",
},
- }
+ },
},
)
- schema = ensure_schema(
- {
- "c": {
- "db": {
- "x": {
- "a": "uint64",
- }
- }
- }
- }
+ self.assert_column_names(
+ schema,
+ ("d1.x", ["a"]),
+ ("d2.x", ["c"]),
+ ("y", ["b"]),
+ ("d1.y", ["b"]),
+ ("z.d1.y", ["b"]),
)
- self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"])
- with self.assertRaises(ValueError):
- schema.column_names(table("x", db="db"))
- with self.assertRaises(ValueError):
- schema.column_names(table("x"))
- with self.assertRaises(ValueError):
- schema.column_names(table("x", db="db", catalog="c2"))
- with self.assertRaises(ValueError):
- schema.column_names(table("x", db="db2"))
- with self.assertRaises(ValueError):
- schema.column_names(table("x2", db="db"))
- with self.assertRaises(ValueError):
- schema.add_table(table("x"), {"b": "string"})
- with self.assertRaises(ValueError):
- schema.add_table(table("x", db="db"), {"b": "string"})
+ self.assert_column_names_raises(
+ schema,
+ "x",
+ "z.x",
+ "z.y",
+ )
- schema.add_table(table("y", db="db", catalog="c"), {"a": "string", "b": "int"})
- self.assertEqual(
- schema.schema,
+ def test_schema_catalog(self):
+ schema = ensure_schema(
{
- "c": {
- "db": {
+ "c1": {
+ "d1": {
"x": {
"a": "uint64",
},
"y": {
- "a": "string",
- "b": "int",
+ "b": "uint64",
},
- }
- }
- },
- )
- schema.add_table(table("z", db="db2", catalog="c"), {"c": "string", "d": "int"})
- self.assertEqual(
- schema.schema,
- {
- "c": {
- "db": {
- "x": {
- "a": "uint64",
+ "z": {
+ "c": "uint64",
},
+ },
+ },
+ "c2": {
+ "d1": {
"y": {
- "a": "string",
- "b": "int",
+ "d": "uint64",
},
- },
- "db2": {
"z": {
- "c": "string",
- "d": "int",
- }
- },
- }
- },
- )
- schema.add_table(table("m", db="db2", catalog="c2"), {"e": "string", "f": "int"})
- self.assertEqual(
- schema.schema,
- {
- "c": {
- "db": {
- "x": {
- "a": "uint64",
- },
- "y": {
- "a": "string",
- "b": "int",
+ "e": "uint64",
},
},
- "db2": {
+ "d2": {
"z": {
- "c": "string",
- "d": "int",
- }
+ "f": "uint64",
+ },
},
},
- "c2": {
- "db2": {
- "m": {
- "e": "string",
- "f": "int",
- }
- }
- },
- },
- )
-
- schema = ensure_schema(
- {
- "x": {
- "a": "uint64",
- }
}
)
- self.assertEqual(schema.column_names(table("x")), ["a"])
- schema = MappingSchema()
- schema.add_table(table("x"), {"a": "string"})
- self.assertEqual(
- schema.schema,
- {
- "x": {
- "a": "string",
- }
- },
+ self.assert_column_names(
+ schema,
+ ("x", ["a"]),
+ ("d1.x", ["a"]),
+ ("c1.d1.x", ["a"]),
+ ("c1.d1.y", ["b"]),
+ ("c1.d1.z", ["c"]),
+ ("c2.d1.y", ["d"]),
+ ("c2.d1.z", ["e"]),
+ ("d2.z", ["f"]),
+ ("c2.d2.z", ["f"]),
)
- schema.add_table(table("y"), df_types.StructType([df_types.StructField("b", df_types.StringType())]))
- self.assertEqual(
- schema.schema,
- {
- "x": {
- "a": "string",
- },
- "y": {
- "b": "string",
- },
- },
+
+ self.assert_column_names_raises(
+ schema,
+ "q",
+ "d2.x",
+ "y",
+ "z",
+ "d1.y",
+ "d1.z",
+ "a.b.c",
)
def test_schema_add_table_with_and_without_mapping(self):
@@ -288,3 +148,34 @@ class TestSchema(unittest.TestCase):
self.assertEqual(schema.column_names("test"), ["x", "y"])
schema.add_table("test")
self.assertEqual(schema.column_names("test"), ["x", "y"])
+
+ def test_schema_get_column_type(self):
+ schema = MappingSchema({"a": {"b": "varchar"}})
+ self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR)
+ self.assertEqual(
+ schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")),
+ exp.DataType.Type.VARCHAR,
+ )
+ self.assertEqual(
+ schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR
+ )
+ self.assertEqual(
+ schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR
+ )
+ schema = MappingSchema({"a": {"b": {"c": "varchar"}}})
+ self.assertEqual(
+ schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")),
+ exp.DataType.Type.VARCHAR,
+ )
+ self.assertEqual(
+ schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR
+ )
+ schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}})
+ self.assertEqual(
+ schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")),
+ exp.DataType.Type.VARCHAR,
+ )
+ self.assertEqual(
+ schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"),
+ exp.DataType.Type.VARCHAR,
+ )
diff --git a/tests/test_tokens.py b/tests/test_tokens.py
new file mode 100644
index 0000000..943c2b0
--- /dev/null
+++ b/tests/test_tokens.py
@@ -0,0 +1,18 @@
+import unittest
+
+from sqlglot.tokens import Tokenizer
+
+
+class TestTokens(unittest.TestCase):
+ def test_comment_attachment(self):
+ tokenizer = Tokenizer()
+ sql_comment = [
+ ("/*comment*/ foo", "comment"),
+ ("/*comment*/ foo --test", "comment"),
+ ("--comment\nfoo --test", "comment"),
+ ("foo --comment", "comment"),
+ ("foo", None),
+ ]
+
+ for sql, comment in sql_comment:
+ self.assertEqual(tokenizer.tokenize(sql)[0].comment, comment)
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index 01b8205..942053e 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -49,6 +49,12 @@ class TestTranspile(unittest.TestCase):
leading_comma=True,
pretty=True,
)
+ self.validate(
+ "SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ",
+ "SELECT\n FOO -- x\n , BAR -- y\n , BAZ",
+ leading_comma=True,
+ pretty=True,
+ )
# without pretty, this should be a no-op
self.validate(
"SELECT FOO, BAR, BAZ",
@@ -63,24 +69,61 @@ class TestTranspile(unittest.TestCase):
self.validate("SELECT 3>=3", "SELECT 3 >= 3")
def test_comments(self):
- self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo")
- self.validate("SELECT 1 /* inline */ FROM foo -- comment", "SELECT 1 FROM foo")
-
+ self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
+ self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
+ self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo")
+ self.validate(
+ "SELECT 1 /* inline */ FROM foo -- comment",
+ "SELECT 1 /* inline */ FROM foo /* comment */",
+ )
+ self.validate(
+ "SELECT FUN(x) /*x*/, [1,2,3] /*y*/", "SELECT FUN(x) /* x */, ARRAY(1, 2, 3) /* y */"
+ )
self.validate(
"""
SELECT 1 -- comment
FROM foo -- comment
""",
- "SELECT 1 FROM foo",
+ "SELECT 1 /* comment */ FROM foo /* comment */",
)
-
self.validate(
"""
SELECT 1 /* big comment
like this */
FROM foo -- comment
""",
- "SELECT 1 FROM foo",
+ """SELECT 1 /* big comment
+ like this */ FROM foo /* comment */""",
+ )
+ self.validate(
+ "select x from foo -- x",
+ "SELECT x FROM foo /* x */",
+ )
+ self.validate(
+ """
+ /* multi
+ line
+ comment
+ */
+ SELECT
+ tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
+ CAST(x AS INT), # comment 3
+ y -- comment 4
+ FROM
+ bar /* comment 5 */,
+ tbl # comment 6
+ """,
+ """/* multi
+ line
+ comment
+ */
+SELECT
+ tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
+ CAST(x AS INT), -- comment 3
+ y -- comment 4
+FROM bar /* comment 5 */, tbl /* comment 6 */""",
+ read="mysql",
+ pretty=True,
)
def test_types(self):
@@ -146,6 +189,16 @@ class TestTranspile(unittest.TestCase):
def test_ignore_nulls(self):
self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)")
+ def test_with(self):
+ self.validate(
+ "WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *",
+ "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
+ )
+ self.validate(
+ "WITH a AS (SELECT 1), WITH b AS (SELECT 2) SELECT *",
+ "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
+ )
+
def test_time(self):
self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)")
self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)")