From 8b4272814fb4585be120f183eb7c26bb8acde974 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 21 Oct 2022 11:29:26 +0200 Subject: Merging upstream version 9.0.1. Signed-off-by: Daniel Baumann --- tests/dataframe/__init__.py | 0 tests/dataframe/integration/__init__.py | 0 tests/dataframe/integration/dataframe_validator.py | 149 ++ tests/dataframe/integration/test_dataframe.py | 1103 ++++++++++++++ tests/dataframe/integration/test_grouped_data.py | 71 + tests/dataframe/integration/test_session.py | 28 + tests/dataframe/unit/__init__.py | 0 tests/dataframe/unit/dataframe_sql_validator.py | 35 + tests/dataframe/unit/test_column.py | 167 ++ tests/dataframe/unit/test_dataframe.py | 39 + tests/dataframe/unit/test_dataframe_writer.py | 86 ++ tests/dataframe/unit/test_functions.py | 1593 ++++++++++++++++++++ tests/dataframe/unit/test_session.py | 114 ++ tests/dataframe/unit/test_types.py | 70 + tests/dataframe/unit/test_window.py | 60 + tests/dialects/test_dialect.py | 25 +- tests/dialects/test_duckdb.py | 23 +- tests/dialects/test_mysql.py | 18 + tests/dialects/test_postgres.py | 10 + tests/dialects/test_presto.py | 4 +- tests/dialects/test_snowflake.py | 23 +- tests/dialects/test_tsql.py | 26 + tests/fixtures/identity.sql | 7 + tests/fixtures/optimizer/merge_subqueries.sql | 24 + tests/fixtures/optimizer/pushdown_predicates.sql | 3 + tests/fixtures/optimizer/pushdown_projections.sql | 3 + tests/fixtures/optimizer/qualify_columns.sql | 3 + tests/helpers.py | 2 + tests/test_executor.py | 8 +- tests/test_expressions.py | 25 +- tests/test_optimizer.py | 119 +- tests/test_parser.py | 6 +- tests/test_schema.py | 290 ++++ 33 files changed, 4000 insertions(+), 134 deletions(-) create mode 100644 tests/dataframe/__init__.py create mode 100644 tests/dataframe/integration/__init__.py create mode 100644 tests/dataframe/integration/dataframe_validator.py create mode 100644 tests/dataframe/integration/test_dataframe.py create mode 100644 tests/dataframe/integration/test_grouped_data.py create mode 100644 tests/dataframe/integration/test_session.py create mode 100644 tests/dataframe/unit/__init__.py create mode 100644 tests/dataframe/unit/dataframe_sql_validator.py create mode 100644 tests/dataframe/unit/test_column.py create mode 100644 tests/dataframe/unit/test_dataframe.py create mode 100644 tests/dataframe/unit/test_dataframe_writer.py create mode 100644 tests/dataframe/unit/test_functions.py create mode 100644 tests/dataframe/unit/test_session.py create mode 100644 tests/dataframe/unit/test_types.py create mode 100644 tests/dataframe/unit/test_window.py create mode 100644 tests/test_schema.py (limited to 'tests') diff --git a/tests/dataframe/__init__.py b/tests/dataframe/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/dataframe/integration/__init__.py b/tests/dataframe/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py new file mode 100644 index 0000000..6c4642f --- /dev/null +++ b/tests/dataframe/integration/dataframe_validator.py @@ -0,0 +1,149 @@ +import typing as t +import unittest +import warnings + +import sqlglot +from tests.helpers import SKIP_INTEGRATION + +if t.TYPE_CHECKING: + from pyspark.sql import DataFrame as SparkDataFrame + + +@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set") +class DataFrameValidator(unittest.TestCase): + spark = None + sqlglot = None + df_employee = None + df_store = None + df_district = None + spark_employee_schema = None + sqlglot_employee_schema = None + spark_store_schema = None + sqlglot_store_schema = None + spark_district_schema = None + sqlglot_district_schema = None + + @classmethod + def setUpClass(cls): + from pyspark import SparkConf + from pyspark.sql import SparkSession, types + + from sqlglot.dataframe.sql import types as sqlglotSparkTypes + from sqlglot.dataframe.sql.session import SparkSession as SqlglotSparkSession + + # This is for test `test_branching_root_dataframes` + config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")]) + cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate() + cls.spark.sparkContext.setLogLevel("ERROR") + cls.sqlglot = SqlglotSparkSession() + cls.spark_employee_schema = types.StructType( + [ + types.StructField("employee_id", types.IntegerType(), False), + types.StructField("fname", types.StringType(), False), + types.StructField("lname", types.StringType(), False), + types.StructField("age", types.IntegerType(), False), + types.StructField("store_id", types.IntegerType(), False), + ] + ) + cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType( + [ + sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False), + sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False), + sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False), + ] + ) + employee_data = [ + (1, "Jack", "Shephard", 37, 1), + (2, "John", "Locke", 65, 1), + (3, "Kate", "Austen", 37, 2), + (4, "Claire", "Littleton", 27, 2), + (5, "Hugo", "Reyes", 29, 100), + ] + cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema) + cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema) + cls.df_employee.createOrReplaceTempView("employee") + + cls.spark_store_schema = types.StructType( + [ + types.StructField("store_id", types.IntegerType(), False), + types.StructField("store_name", types.StringType(), False), + types.StructField("district_id", types.IntegerType(), False), + types.StructField("num_sales", types.IntegerType(), False), + ] + ) + cls.sqlglot_store_schema = sqlglotSparkTypes.StructType( + [ + sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False), + sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False), + ] + ) + store_data = [ + (1, "Hydra", 1, 37), + (2, "Arrow", 2, 2000), + ] + cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema) + cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema) + cls.df_store.createOrReplaceTempView("store") + + cls.spark_district_schema = types.StructType( + [ + types.StructField("district_id", types.IntegerType(), False), + types.StructField("district_name", types.StringType(), False), + types.StructField("manager_name", types.StringType(), False), + ] + ) + cls.sqlglot_district_schema = sqlglotSparkTypes.StructType( + [ + sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False), + sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False), + ] + ) + district_data = [ + (1, "Temple", "Dogen"), + (2, "Lighthouse", "Jacob"), + ] + cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema) + cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema) + cls.df_district.createOrReplaceTempView("district") + sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema) + sqlglot.schema.add_table("store", cls.sqlglot_store_schema) + sqlglot.schema.add_table("district", cls.sqlglot_district_schema) + + def setUp(self) -> None: + warnings.filterwarnings("ignore", category=ResourceWarning) + self.df_spark_store = self.df_store.alias("df_store") # type: ignore + self.df_spark_employee = self.df_employee.alias("df_employee") # type: ignore + self.df_spark_district = self.df_district.alias("df_district") # type: ignore + self.df_sqlglot_store = self.dfs_store.alias("store") # type: ignore + self.df_sqlglot_employee = self.dfs_employee.alias("employee") # type: ignore + self.df_sqlglot_district = self.dfs_district.alias("district") # type: ignore + + def compare_spark_with_sqlglot( + self, df_spark, df_sqlglot, no_empty=True, skip_schema_compare=False + ) -> t.Tuple["SparkDataFrame", "SparkDataFrame"]: + def compare_schemas(schema_1, schema_2): + for schema in [schema_1, schema_2]: + for struct_field in schema.fields: + struct_field.metadata = {} + self.assertEqual(schema_1, schema_2) + + for statement in df_sqlglot.sql(): + actual_df_sqlglot = self.spark.sql(statement) # type: ignore + df_sqlglot_results = actual_df_sqlglot.collect() + df_spark_results = df_spark.collect() + if not skip_schema_compare: + compare_schemas(df_spark.schema, actual_df_sqlglot.schema) + self.assertEqual(df_spark_results, df_sqlglot_results) + if no_empty: + self.assertNotEqual(len(df_spark_results), 0) + self.assertNotEqual(len(df_sqlglot_results), 0) + return df_spark, actual_df_sqlglot + + @classmethod + def get_explain_plan(cls, df: "SparkDataFrame", mode: str = "extended") -> str: + return df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), mode) # type: ignore diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py new file mode 100644 index 0000000..c740bec --- /dev/null +++ b/tests/dataframe/integration/test_dataframe.py @@ -0,0 +1,1103 @@ +from pyspark.sql import functions as F + +from sqlglot.dataframe.sql import functions as SF +from tests.dataframe.integration.dataframe_validator import DataFrameValidator + + +class TestDataframeFunc(DataFrameValidator): + def test_simple_select(self): + df_employee = self.df_spark_employee.select(F.col("employee_id")) + dfs_employee = self.df_sqlglot_employee.select(SF.col("employee_id")) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_simple_select_from_table(self): + df = self.df_spark_employee + dfs = self.sqlglot.read.table("employee") + self.compare_spark_with_sqlglot(df, dfs) + + def test_simple_select_df_attribute(self): + df_employee = self.df_spark_employee.select(self.df_spark_employee.employee_id) + dfs_employee = self.df_sqlglot_employee.select(self.df_sqlglot_employee.employee_id) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_simple_select_df_dict(self): + df_employee = self.df_spark_employee.select(self.df_spark_employee["employee_id"]) + dfs_employee = self.df_sqlglot_employee.select(self.df_sqlglot_employee["employee_id"]) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_multiple_selects(self): + df_employee = self.df_spark_employee.select( + self.df_spark_employee["employee_id"], F.col("fname"), self.df_spark_employee.lname + ) + dfs_employee = self.df_sqlglot_employee.select( + self.df_sqlglot_employee["employee_id"], SF.col("fname"), self.df_sqlglot_employee.lname + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_alias_no_op(self): + df_employee = self.df_spark_employee.alias("df_employee") + dfs_employee = self.df_sqlglot_employee.alias("dfs_employee") + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_alias_with_select(self): + df_employee = self.df_spark_employee.alias("df_employee").select( + self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname + ) + dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select( + self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_case_when_otherwise(self): + df = self.df_spark_employee.select( + F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")) + .when(F.col("age") < F.lit(40), "less than 40") + .otherwise("greater than 60") + ) + + dfs = self.df_sqlglot_employee.select( + SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")) + .when(SF.col("age") < SF.lit(40), "less than 40") + .otherwise("greater than 60") + ) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_case_when_no_otherwise(self): + df = self.df_spark_employee.select( + F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when( + F.col("age") < F.lit(40), "less than 40" + ) + ) + + dfs = self.df_sqlglot_employee.select( + SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when( + SF.col("age") < SF.lit(40), "less than 40" + ) + ) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_where_clause_single(self): + df_employee = self.df_spark_employee.where(F.col("age") == F.lit(37)) + dfs_employee = self.df_sqlglot_employee.where(SF.col("age") == SF.lit(37)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_where_clause_multiple_and(self): + df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))) + dfs_employee = self.df_sqlglot_employee.where( + (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack")) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_where_many_and(self): + df_employee = self.df_spark_employee.where( + (F.col("age") == F.lit(37)) + & (F.col("fname") == F.lit("Jack")) + & (F.col("lname") == F.lit("Shephard")) + & (F.col("employee_id") == F.lit(1)) + ) + dfs_employee = self.df_sqlglot_employee.where( + (SF.col("age") == SF.lit(37)) + & (SF.col("fname") == SF.lit("Jack")) + & (SF.col("lname") == SF.lit("Shephard")) + & (SF.col("employee_id") == SF.lit(1)) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_where_clause_multiple_or(self): + df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))) + dfs_employee = self.df_sqlglot_employee.where( + (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate")) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_where_many_or(self): + df_employee = self.df_spark_employee.where( + (F.col("age") == F.lit(37)) + | (F.col("fname") == F.lit("Kate")) + | (F.col("lname") == F.lit("Littleton")) + | (F.col("employee_id") == F.lit(2)) + ) + dfs_employee = self.df_sqlglot_employee.where( + (SF.col("age") == SF.lit(37)) + | (SF.col("fname") == SF.lit("Kate")) + | (SF.col("lname") == SF.lit("Littleton")) + | (SF.col("employee_id") == SF.lit(2)) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_where_mixed_and_or(self): + df_employee = self.df_spark_employee.where( + ((F.col("age") == F.lit(65)) & (F.col("fname") == F.lit("John"))) + | ((F.col("lname") == F.lit("Shephard")) & (F.col("age") == F.lit(37))) + ) + dfs_employee = self.df_sqlglot_employee.where( + ((SF.col("age") == SF.lit(65)) & (SF.col("fname") == SF.lit("John"))) + | ((SF.col("lname") == SF.lit("Shephard")) & (SF.col("age") == SF.lit(37))) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_where_multiple_chained(self): + df_employee = self.df_spark_employee.where(F.col("age") == F.lit(37)).where( + self.df_spark_employee.fname == F.lit("Jack") + ) + dfs_employee = self.df_sqlglot_employee.where(SF.col("age") == SF.lit(37)).where( + self.df_sqlglot_employee.fname == SF.lit("Jack") + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_operators(self): + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] < F.lit(50)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] < SF.lit(50)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] <= F.lit(37)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] <= SF.lit(37)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] > F.lit(50)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] > SF.lit(50)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] >= F.lit(37)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] >= SF.lit(37)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] != F.lit(50)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] != SF.lit(50)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] == F.lit(37)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28)) + dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_join_inner(self): + df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + F.col("store_id"), + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + SF.col("store_id"), + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_join_inner_no_select(self): + df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join( + self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner" + ) + dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner" + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_join_inner_equality_single(self): + df_joined = self.df_spark_employee.join( + self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner" + ).select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + self.df_spark_employee.store_id, + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner" + ).select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + self.df_sqlglot_employee.store_id, + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_join_inner_equality_multiple(self): + df_joined = self.df_spark_employee.join( + self.df_spark_store, + on=[ + self.df_spark_employee.store_id == self.df_spark_store.store_id, + self.df_spark_employee.age == self.df_spark_store.num_sales, + ], + how="inner", + ).select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + self.df_spark_employee.store_id, + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, + on=[ + self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, + self.df_sqlglot_employee.age == self.df_sqlglot_store.num_sales, + ], + how="inner", + ).select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + self.df_sqlglot_employee.store_id, + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_join_inner_equality_multiple_bitwise_and(self): + df_joined = self.df_spark_employee.join( + self.df_spark_store, + on=(self.df_spark_employee.store_id == self.df_spark_store.store_id) + & (self.df_spark_employee.age == self.df_spark_store.num_sales), + how="inner", + ).select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + self.df_spark_employee.store_id, + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, + on=(self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id) + & (self.df_sqlglot_employee.age == self.df_sqlglot_store.num_sales), + how="inner", + ).select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + self.df_sqlglot_employee.store_id, + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_join_left_outer(self): + df_joined = ( + self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="left_outer") + .select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + F.col("store_id"), + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + .orderBy(F.col("employee_id")) + ) + dfs_joined = ( + self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="left_outer") + .select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + SF.col("store_id"), + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + .orderBy(SF.col("employee_id")) + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_join_full_outer(self): + df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + F.col("store_id"), + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + SF.col("store_id"), + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_triple_join(self): + df = ( + self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id) + .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id) + .select( + self.df_employee.employee_id, + self.df_store.store_id, + self.df_district.district_id, + self.df_employee.fname, + self.df_store.store_name, + self.df_district.district_name, + ) + ) + dfs = ( + self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id) + .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id) + .select( + self.dfs_employee.employee_id, + self.dfs_store.store_id, + self.dfs_district.district_id, + self.dfs_employee.fname, + self.dfs_store.store_name, + self.dfs_district.district_name, + ) + ) + self.compare_spark_with_sqlglot(df, dfs) + + def test_join_select_and_select_start(self): + df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join( + self.df_spark_store, "store_id", "inner" + ) + + dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join( + self.df_sqlglot_store, "store_id", "inner" + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_branching_root_dataframes(self): + """ + Test a pattern that has non-intuitive behavior in spark + + Scenario: You do a self-join in a dataframe using an original dataframe and then a modified version + of it. You then reference the columns by the dataframe name instead of the column function. + Spark will use the root dataframe's column in the result. + """ + df_hydra_employees_only = self.df_spark_employee.where(F.col("store_id") == F.lit(1)) + df_joined = ( + self.df_spark_employee.where(F.col("store_id") == F.lit(2)) + .alias("df_arrow_employees_only") + .join( + df_hydra_employees_only.alias("df_hydra_employees_only"), + on=["store_id"], + how="full_outer", + ) + .select( + self.df_spark_employee.fname, + F.col("df_arrow_employees_only.fname"), + df_hydra_employees_only.fname, + F.col("df_hydra_employees_only.fname"), + ) + ) + + dfs_hydra_employees_only = self.df_sqlglot_employee.where(SF.col("store_id") == SF.lit(1)) + dfs_joined = ( + self.df_sqlglot_employee.where(SF.col("store_id") == SF.lit(2)) + .alias("dfs_arrow_employees_only") + .join( + dfs_hydra_employees_only.alias("dfs_hydra_employees_only"), + on=["store_id"], + how="full_outer", + ) + .select( + self.df_sqlglot_employee.fname, + SF.col("dfs_arrow_employees_only.fname"), + dfs_hydra_employees_only.fname, + SF.col("dfs_hydra_employees_only.fname"), + ) + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_basic_union(self): + df_unioned = self.df_spark_employee.select(F.col("employee_id"), F.col("age")).union( + self.df_spark_store.select(F.col("store_id"), F.col("num_sales")) + ) + + dfs_unioned = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age")).union( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("num_sales")) + ) + self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) + + def test_union_with_join(self): + df_joined = self.df_spark_employee.join( + self.df_spark_store, + on="store_id", + how="inner", + ) + df_unioned = df_joined.select(F.col("store_id"), F.col("store_name")).union( + self.df_spark_district.select(F.col("district_id"), F.col("district_name")) + ) + + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, + on="store_id", + how="inner", + ) + dfs_unioned = dfs_joined.select(SF.col("store_id"), SF.col("store_name")).union( + self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")) + ) + + self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) + + def test_double_union_all(self): + df_unioned = ( + self.df_spark_employee.select(F.col("employee_id"), F.col("fname")) + .unionAll(self.df_spark_store.select(F.col("store_id"), F.col("store_name"))) + .unionAll(self.df_spark_district.select(F.col("district_id"), F.col("district_name"))) + ) + + dfs_unioned = ( + self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname")) + .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))) + .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))) + ) + + self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) + + def test_union_by_name(self): + df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName( + self.df_spark_store.select( + F.col("store_name").alias("lname"), + F.col("store_id").alias("employee_id"), + F.col("store_name").alias("fname"), + ) + ) + + dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName( + self.df_sqlglot_store.select( + SF.col("store_name").alias("lname"), + SF.col("store_id").alias("employee_id"), + SF.col("store_name").alias("fname"), + ) + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_union_by_name_allow_missing(self): + df = self.df_spark_employee.select( + F.col("age"), F.col("employee_id"), F.col("fname"), F.col("lname") + ).unionByName( + self.df_spark_store.select( + F.col("store_name").alias("lname"), + F.col("store_id").alias("employee_id"), + F.col("store_name").alias("fname"), + F.col("num_sales"), + ), + allowMissingColumns=True, + ) + + dfs = self.df_sqlglot_employee.select( + SF.col("age"), SF.col("employee_id"), SF.col("fname"), SF.col("lname") + ).unionByName( + self.df_sqlglot_store.select( + SF.col("store_name").alias("lname"), + SF.col("store_id").alias("employee_id"), + SF.col("store_name").alias("fname"), + SF.col("num_sales"), + ), + allowMissingColumns=True, + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_order_by_default(self): + df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id")) + + dfs = ( + self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id")) + ) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_order_by_array_bool(self): + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales").alias("total_sales")) + .orderBy(F.col("total_sales"), F.col("district_id"), ascending=[1, 0]) + ) + + dfs = ( + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales").alias("total_sales")) + .orderBy(SF.col("total_sales"), SF.col("district_id"), ascending=[1, 0]) + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_order_by_single_bool(self): + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales").alias("total_sales")) + .orderBy(F.col("total_sales"), F.col("district_id"), ascending=False) + ) + + dfs = ( + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales").alias("total_sales")) + .orderBy(SF.col("total_sales"), SF.col("district_id"), ascending=False) + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_order_by_column_sort_method(self): + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales").alias("total_sales")) + .orderBy(F.col("total_sales").asc(), F.col("district_id").desc()) + ) + + dfs = ( + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales").alias("total_sales")) + .orderBy(SF.col("total_sales").asc(), SF.col("district_id").desc()) + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_order_by_column_sort_method_nulls_last(self): + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales").alias("total_sales")) + .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()) + ) + + dfs = ( + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales").alias("total_sales")) + .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()) + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_order_by_column_sort_method_nulls_first(self): + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales").alias("total_sales")) + .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()) + ) + + dfs = ( + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales").alias("total_sales")) + .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first()) + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_intersect(self): + df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( + self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) + ) + + df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( + self.df_spark_store.select(F.col("store_id"), F.col("district_id")) + ) + + df = df_employee_duplicate.intersect(df_store_duplicate) + + dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( + self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) + ) + + dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) + ) + + dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_intersect_all(self): + df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( + self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) + ) + + df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( + self.df_spark_store.select(F.col("store_id"), F.col("district_id")) + ) + + df = df_employee_duplicate.intersectAll(df_store_duplicate) + + dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( + self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) + ) + + dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) + ) + + dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_except_all(self): + df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( + self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) + ) + + df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( + self.df_spark_store.select(F.col("store_id"), F.col("district_id")) + ) + + df = df_employee_duplicate.exceptAll(df_store_duplicate) + + dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( + self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) + ) + + dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) + ) + + dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_distinct(self): + df = self.df_spark_employee.select(F.col("age")).distinct() + + dfs = self.df_sqlglot_employee.select(SF.col("age")).distinct() + + self.compare_spark_with_sqlglot(df, dfs) + + def test_union_distinct(self): + df_unioned = ( + self.df_spark_employee.select(F.col("employee_id"), F.col("age")) + .union(self.df_spark_employee.select(F.col("employee_id"), F.col("age"))) + .distinct() + ) + + dfs_unioned = ( + self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age")) + .union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age"))) + .distinct() + ) + self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) + + def test_drop_duplicates_no_subset(self): + df = self.df_spark_employee.select("age").dropDuplicates() + dfs = self.df_sqlglot_employee.select("age").dropDuplicates() + self.compare_spark_with_sqlglot(df, dfs) + + def test_drop_duplicates_subset(self): + df = self.df_spark_employee.dropDuplicates(["age"]) + dfs = self.df_sqlglot_employee.dropDuplicates(["age"]) + self.compare_spark_with_sqlglot(df, dfs) + + def test_drop_na_default(self): + df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna() + + dfs = self.df_sqlglot_employee.select( + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + ).dropna() + + self.compare_spark_with_sqlglot(df, dfs) + + def test_dropna_how(self): + df = self.df_spark_employee.select( + F.lit(None), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).dropna(how="all") + + dfs = self.df_sqlglot_employee.select( + SF.lit(None), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + ).dropna(how="all") + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_dropna_thresh(self): + df = self.df_spark_employee.select( + F.lit(None), F.lit(1), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).dropna(how="any", thresh=2) + + dfs = self.df_sqlglot_employee.select( + SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + ).dropna(how="any", thresh=2) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_dropna_subset(self): + df = self.df_spark_employee.select( + F.lit(None), F.lit(1), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).dropna(thresh=1, subset="the_age") + + dfs = self.df_sqlglot_employee.select( + SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + ).dropna(thresh=1, subset="the_age") + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_dropna_na_function(self): + df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop() + + dfs = self.df_sqlglot_employee.select( + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + ).na.drop() + + self.compare_spark_with_sqlglot(df, dfs) + + def test_fillna_default(self): + df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100) + + dfs = self.df_sqlglot_employee.select( + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + ).fillna(100) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_fillna_dict_replacement(self): + df = self.df_spark_employee.select( + F.col("fname"), + F.when(F.col("lname").startswith("L"), F.col("lname")).alias("l_lname"), + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age"), + ).fillna({"fname": "Jacob", "l_lname": "NOT_LNAME"}) + + dfs = self.df_sqlglot_employee.select( + SF.col("fname"), + SF.when(SF.col("lname").startswith("L"), SF.col("lname")).alias("l_lname"), + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), + ).fillna({"fname": "Jacob", "l_lname": "NOT_LNAME"}) + + # For some reason the sqlglot results sets a column as nullable when it doesn't need to + # This seems to be a nuance in how spark dataframe from sql works so we can ignore + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_fillna_na_func(self): + df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100) + + dfs = self.df_sqlglot_employee.select( + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + ).na.fill(100) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_replace_basic(self): + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100) + + dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( + to_replace=37, value=100 + ) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_replace_basic_subset(self): + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( + to_replace=37, value=100, subset="age" + ) + + dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( + to_replace=37, value=100, subset="age" + ) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_replace_mapping(self): + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100}) + + dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100}) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_replace_mapping_subset(self): + df = self.df_spark_employee.select( + F.col("age"), F.lit(37).alias("test_col"), F.lit(50).alias("test_col_2") + ).replace({37: 100, 50: 1}, subset=["age", "test_col_2"]) + + dfs = self.df_sqlglot_employee.select( + SF.col("age"), SF.lit(37).alias("test_col"), SF.lit(50).alias("test_col_2") + ).replace({37: 100, 50: 1}, subset=["age", "test_col_2"]) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_replace_na_func_basic(self): + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).na.replace( + to_replace=37, value=100 + ) + + dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace( + to_replace=37, value=100 + ) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_with_column(self): + df = self.df_spark_employee.withColumn("test", F.col("age")) + + dfs = self.df_sqlglot_employee.withColumn("test", SF.col("age")) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_with_column_existing_name(self): + df = self.df_spark_employee.withColumn("fname", F.lit("blah")) + + dfs = self.df_sqlglot_employee.withColumn("fname", SF.lit("blah")) + + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) + + def test_with_column_renamed(self): + df = self.df_spark_employee.withColumnRenamed("fname", "first_name") + + dfs = self.df_sqlglot_employee.withColumnRenamed("fname", "first_name") + + self.compare_spark_with_sqlglot(df, dfs) + + def test_with_column_renamed_double(self): + df = self.df_spark_employee.select(F.col("fname").alias("first_name")).withColumnRenamed( + "first_name", "first_name_again" + ) + + dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed( + "first_name", "first_name_again" + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_drop_column_single(self): + df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age") + + dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age") + + self.compare_spark_with_sqlglot(df, dfs) + + def test_drop_column_reference_join(self): + df_spark_employee_cols = self.df_spark_employee.select( + F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id") + ) + df_spark_store_cols = self.df_spark_store.select(F.col("store_id"), F.col("store_name")) + df = df_spark_employee_cols.join(df_spark_store_cols, on="store_id", how="inner").drop( + df_spark_employee_cols.age, + ) + + df_sqlglot_employee_cols = self.df_sqlglot_employee.select( + SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id") + ) + df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")) + dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop( + df_sqlglot_employee_cols.age, + ) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_limit(self): + df = self.df_spark_employee.limit(1) + + dfs = self.df_sqlglot_employee.limit(1) + + self.compare_spark_with_sqlglot(df, dfs) + + def test_hint_broadcast_alias(self): + df_joined = self.df_spark_employee.join( + self.df_spark_store.alias("store").hint("broadcast", "store"), + on=self.df_spark_employee.store_id == self.df_spark_store.store_id, + how="inner", + ).select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + self.df_spark_employee.store_id, + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store.alias("store").hint("broadcast", "store"), + on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, + how="inner", + ).select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + self.df_sqlglot_employee.store_id, + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined) + self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df)) + self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs)) + + def test_hint_broadcast_no_alias(self): + df_joined = self.df_spark_employee.join( + self.df_spark_store.hint("broadcast"), + on=self.df_spark_employee.store_id == self.df_spark_store.store_id, + how="inner", + ).select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + self.df_spark_employee.store_id, + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store.hint("broadcast"), + on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, + how="inner", + ).select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + self.df_sqlglot_employee.store_id, + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined) + self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df)) + self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs)) + + # TODO: Add test to make sure with and without alias are the same once ids are deterministic + + def test_broadcast_func(self): + df_joined = self.df_spark_employee.join( + F.broadcast(self.df_spark_store), + on=self.df_spark_employee.store_id == self.df_spark_store.store_id, + how="inner", + ).select( + self.df_spark_employee.employee_id, + self.df_spark_employee["fname"], + F.col("lname"), + F.col("age"), + self.df_spark_employee.store_id, + self.df_spark_store.store_name, + self.df_spark_store["num_sales"], + ) + dfs_joined = self.df_sqlglot_employee.join( + SF.broadcast(self.df_sqlglot_store), + on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, + how="inner", + ).select( + self.df_sqlglot_employee.employee_id, + self.df_sqlglot_employee["fname"], + SF.col("lname"), + SF.col("age"), + self.df_sqlglot_employee.store_id, + self.df_sqlglot_store.store_name, + self.df_sqlglot_store["num_sales"], + ) + df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined) + self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df)) + self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs)) + + def test_repartition_by_num(self): + """ + The results are different when doing the repartition on a table created using VALUES in SQL. + So I just use the views instead for these tests + """ + df = self.df_spark_employee.repartition(63) + + dfs = self.sqlglot.read.table("employee").repartition(63) + df, dfs = self.compare_spark_with_sqlglot(df, dfs) + spark_num_partitions = df.rdd.getNumPartitions() + sqlglot_num_partitions = dfs.rdd.getNumPartitions() + self.assertEqual(spark_num_partitions, 63) + self.assertEqual(spark_num_partitions, sqlglot_num_partitions) + + def test_repartition_name_only(self): + """ + We use the view here to help ensure the explain plans are similar enough to compare + """ + df = self.df_spark_employee.repartition("age") + + dfs = self.sqlglot.read.table("employee").repartition("age") + df, dfs = self.compare_spark_with_sqlglot(df, dfs) + self.assertIn("RepartitionByExpression [age", self.get_explain_plan(df)) + self.assertIn("RepartitionByExpression [age", self.get_explain_plan(dfs)) + + def test_repartition_num_and_multiple_names(self): + """ + We use the view here to help ensure the explain plans are similar enough to compare + """ + df = self.df_spark_employee.repartition(53, "age", "fname") + + dfs = self.sqlglot.read.table("employee").repartition(53, "age", "fname") + df, dfs = self.compare_spark_with_sqlglot(df, dfs) + spark_num_partitions = df.rdd.getNumPartitions() + sqlglot_num_partitions = dfs.rdd.getNumPartitions() + self.assertEqual(spark_num_partitions, 53) + self.assertEqual(spark_num_partitions, sqlglot_num_partitions) + self.assertIn("RepartitionByExpression [age#3, fname#1], 53", self.get_explain_plan(df)) + self.assertIn("RepartitionByExpression [age#3, fname#1], 53", self.get_explain_plan(dfs)) + + def test_coalesce(self): + df = self.df_spark_employee.coalesce(1) + dfs = self.df_sqlglot_employee.coalesce(1) + df, dfs = self.compare_spark_with_sqlglot(df, dfs) + spark_num_partitions = df.rdd.getNumPartitions() + sqlglot_num_partitions = dfs.rdd.getNumPartitions() + self.assertEqual(spark_num_partitions, 1) + self.assertEqual(spark_num_partitions, sqlglot_num_partitions) + + def test_cache_select(self): + df_employee = ( + self.df_spark_employee.groupBy("store_id") + .agg(F.countDistinct("employee_id").alias("num_employees")) + .cache() + ) + df_joined = df_employee.join(self.df_spark_store, on="store_id").select( + self.df_spark_store.store_id, df_employee.num_employees + ) + dfs_employee = ( + self.df_sqlglot_employee.groupBy("store_id") + .agg(SF.countDistinct("employee_id").alias("num_employees")) + .cache() + ) + dfs_joined = dfs_employee.join(self.df_sqlglot_store, on="store_id").select( + self.df_sqlglot_store.store_id, dfs_employee.num_employees + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) + + def test_persist_select(self): + df_employee = ( + self.df_spark_employee.groupBy("store_id") + .agg(F.countDistinct("employee_id").alias("num_employees")) + .persist() + ) + df_joined = df_employee.join(self.df_spark_store, on="store_id").select( + self.df_spark_store.store_id, df_employee.num_employees + ) + dfs_employee = ( + self.df_sqlglot_employee.groupBy("store_id") + .agg(SF.countDistinct("employee_id").alias("num_employees")) + .persist() + ) + dfs_joined = dfs_employee.join(self.df_sqlglot_store, on="store_id").select( + self.df_sqlglot_store.store_id, dfs_employee.num_employees + ) + self.compare_spark_with_sqlglot(df_joined, dfs_joined) diff --git a/tests/dataframe/integration/test_grouped_data.py b/tests/dataframe/integration/test_grouped_data.py new file mode 100644 index 0000000..2768dda --- /dev/null +++ b/tests/dataframe/integration/test_grouped_data.py @@ -0,0 +1,71 @@ +from pyspark.sql import functions as F + +from sqlglot.dataframe.sql import functions as SF +from tests.dataframe.integration.dataframe_validator import DataFrameValidator + + +class TestDataframeFunc(DataFrameValidator): + def test_group_by(self): + df_employee = self.df_spark_employee.groupBy(self.df_spark_employee.age).agg( + F.min(self.df_spark_employee.employee_id) + ) + dfs_employee = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).agg( + SF.min(self.df_sqlglot_employee.employee_id) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee, skip_schema_compare=True) + + def test_group_by_where_non_aggregate(self): + df_employee = ( + self.df_spark_employee.groupBy(self.df_spark_employee.age) + .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id")) + .where(F.col("age") > F.lit(50)) + ) + dfs_employee = ( + self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age) + .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id")) + .where(SF.col("age") > SF.lit(50)) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_group_by_where_aggregate_like_having(self): + df_employee = ( + self.df_spark_employee.groupBy(self.df_spark_employee.age) + .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id")) + .where(F.col("min_employee_id") > F.lit(1)) + ) + dfs_employee = ( + self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age) + .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id")) + .where(SF.col("min_employee_id") > SF.lit(1)) + ) + self.compare_spark_with_sqlglot(df_employee, dfs_employee) + + def test_count(self): + df = self.df_spark_employee.groupBy(self.df_spark_employee.age).count() + dfs = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).count() + self.compare_spark_with_sqlglot(df, dfs) + + def test_mean(self): + df = self.df_spark_employee.groupBy().mean("age", "store_id") + dfs = self.df_sqlglot_employee.groupBy().mean("age", "store_id") + self.compare_spark_with_sqlglot(df, dfs) + + def test_avg(self): + df = self.df_spark_employee.groupBy("age").avg("store_id") + dfs = self.df_sqlglot_employee.groupBy("age").avg("store_id") + self.compare_spark_with_sqlglot(df, dfs) + + def test_max(self): + df = self.df_spark_employee.groupBy("age").max("store_id") + dfs = self.df_sqlglot_employee.groupBy("age").max("store_id") + self.compare_spark_with_sqlglot(df, dfs) + + def test_min(self): + df = self.df_spark_employee.groupBy("age").min("store_id") + dfs = self.df_sqlglot_employee.groupBy("age").min("store_id") + self.compare_spark_with_sqlglot(df, dfs) + + def test_sum(self): + df = self.df_spark_employee.groupBy("age").sum("store_id") + dfs = self.df_sqlglot_employee.groupBy("age").sum("store_id") + self.compare_spark_with_sqlglot(df, dfs) diff --git a/tests/dataframe/integration/test_session.py b/tests/dataframe/integration/test_session.py new file mode 100644 index 0000000..ff1477b --- /dev/null +++ b/tests/dataframe/integration/test_session.py @@ -0,0 +1,28 @@ +from pyspark.sql import functions as F + +from sqlglot.dataframe.sql import functions as SF +from tests.dataframe.integration.dataframe_validator import DataFrameValidator + + +class TestSessionFunc(DataFrameValidator): + def test_sql_simple_select(self): + query = "SELECT fname, lname FROM employee" + df = self.spark.sql(query) + dfs = self.sqlglot.sql(query) + self.compare_spark_with_sqlglot(df, dfs) + + def test_sql_with_join(self): + query = """ + SELECT + e.employee_id + , s.store_id + FROM + employee e + INNER JOIN + store s + ON + e.store_id = s.store_id + """ + df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id"))) + dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id"))) + self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) diff --git a/tests/dataframe/unit/__init__.py b/tests/dataframe/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py new file mode 100644 index 0000000..fc56553 --- /dev/null +++ b/tests/dataframe/unit/dataframe_sql_validator.py @@ -0,0 +1,35 @@ +import typing as t +import unittest + +from sqlglot.dataframe.sql import types +from sqlglot.dataframe.sql.dataframe import DataFrame +from sqlglot.dataframe.sql.session import SparkSession + + +class DataFrameSQLValidator(unittest.TestCase): + def setUp(self) -> None: + self.spark = SparkSession() + self.employee_schema = types.StructType( + [ + types.StructField("employee_id", types.IntegerType(), False), + types.StructField("fname", types.StringType(), False), + types.StructField("lname", types.StringType(), False), + types.StructField("age", types.IntegerType(), False), + types.StructField("store_id", types.IntegerType(), False), + ] + ) + employee_data = [ + (1, "Jack", "Shephard", 37, 1), + (2, "John", "Locke", 65, 1), + (3, "Kate", "Austen", 37, 2), + (4, "Claire", "Littleton", 27, 2), + (5, "Hugo", "Reyes", 29, 100), + ] + self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema) + + def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False): + actual_sqls = df.sql(pretty=pretty) + expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements + self.assertEqual(len(expected_statements), len(actual_sqls)) + for expected, actual in zip(expected_statements, actual_sqls): + self.assertEqual(expected, actual) diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py new file mode 100644 index 0000000..df0ebff --- /dev/null +++ b/tests/dataframe/unit/test_column.py @@ -0,0 +1,167 @@ +import datetime +import unittest + +from sqlglot.dataframe.sql import functions as F +from sqlglot.dataframe.sql.window import Window + + +class TestDataframeColumn(unittest.TestCase): + def test_eq(self): + self.assertEqual("cola = 1", (F.col("cola") == 1).sql()) + + def test_neq(self): + self.assertEqual("cola <> 1", (F.col("cola") != 1).sql()) + + def test_gt(self): + self.assertEqual("cola > 1", (F.col("cola") > 1).sql()) + + def test_lt(self): + self.assertEqual("cola < 1", (F.col("cola") < 1).sql()) + + def test_le(self): + self.assertEqual("cola <= 1", (F.col("cola") <= 1).sql()) + + def test_ge(self): + self.assertEqual("cola >= 1", (F.col("cola") >= 1).sql()) + + def test_and(self): + self.assertEqual( + "cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql() + ) + + def test_or(self): + self.assertEqual( + "cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql() + ) + + def test_mod(self): + self.assertEqual("cola % 2", (F.col("cola") % 2).sql()) + + def test_add(self): + self.assertEqual("cola + 1", (F.col("cola") + 1).sql()) + + def test_sub(self): + self.assertEqual("cola - 1", (F.col("cola") - 1).sql()) + + def test_mul(self): + self.assertEqual("cola * 2", (F.col("cola") * 2).sql()) + + def test_div(self): + self.assertEqual("cola / 2", (F.col("cola") / 2).sql()) + + def test_radd(self): + self.assertEqual("1 + cola", (1 + F.col("cola")).sql()) + + def test_rsub(self): + self.assertEqual("1 - cola", (1 - F.col("cola")).sql()) + + def test_rmul(self): + self.assertEqual("1 * cola", (1 * F.col("cola")).sql()) + + def test_rdiv(self): + self.assertEqual("1 / cola", (1 / F.col("cola")).sql()) + + def test_pow(self): + self.assertEqual("POWER(cola, 2)", (F.col("cola") ** 2).sql()) + + def test_rpow(self): + self.assertEqual("POWER(2, cola)", (2 ** F.col("cola")).sql()) + + def test_invert(self): + self.assertEqual("NOT cola", (~F.col("cola")).sql()) + + def test_startswith(self): + self.assertEqual("STARTSWITH(cola, 'test')", F.col("cola").startswith("test").sql()) + + def test_endswith(self): + self.assertEqual("ENDSWITH(cola, 'test')", F.col("cola").endswith("test").sql()) + + def test_rlike(self): + self.assertEqual("cola RLIKE 'foo'", F.col("cola").rlike("foo").sql()) + + def test_like(self): + self.assertEqual("cola LIKE 'foo%'", F.col("cola").like("foo%").sql()) + + def test_ilike(self): + self.assertEqual("cola ILIKE 'foo%'", F.col("cola").ilike("foo%").sql()) + + def test_substring(self): + self.assertEqual("SUBSTRING(cola, 2, 3)", F.col("cola").substr(2, 3).sql()) + + def test_isin(self): + self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin([1, 2, 3]).sql()) + self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin(1, 2, 3).sql()) + + def test_asc(self): + self.assertEqual("cola", F.col("cola").asc().sql()) + + def test_desc(self): + self.assertEqual("cola DESC", F.col("cola").desc().sql()) + + def test_asc_nulls_first(self): + self.assertEqual("cola", F.col("cola").asc_nulls_first().sql()) + + def test_asc_nulls_last(self): + self.assertEqual("cola NULLS LAST", F.col("cola").asc_nulls_last().sql()) + + def test_desc_nulls_first(self): + self.assertEqual("cola DESC NULLS FIRST", F.col("cola").desc_nulls_first().sql()) + + def test_desc_nulls_last(self): + self.assertEqual("cola DESC", F.col("cola").desc_nulls_last().sql()) + + def test_when_otherwise(self): + self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql()) + self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()) + self.assertEqual( + "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END", + (F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(), + ) + self.assertEqual( + "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END", + F.col("cola").when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).sql(), + ) + self.assertEqual( + "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 ELSE 4 END", + F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).otherwise(4).sql(), + ) + + def test_is_null(self): + self.assertEqual("cola IS NULL", F.col("cola").isNull().sql()) + + def test_is_not_null(self): + self.assertEqual("NOT cola IS NULL", F.col("cola").isNotNull().sql()) + + def test_cast(self): + self.assertEqual("CAST(cola AS INT)", F.col("cola").cast("INT").sql()) + + def test_alias(self): + self.assertEqual("cola AS new_name", F.col("cola").alias("new_name").sql()) + + def test_between(self): + self.assertEqual("cola BETWEEN 1 AND 3", F.col("cola").between(1, 3).sql()) + self.assertEqual("cola BETWEEN 10.1 AND 12.1", F.col("cola").between(10.1, 12.1).sql()) + self.assertEqual( + "cola BETWEEN TO_DATE('2022-01-01') AND TO_DATE('2022-03-01')", + F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(), + ) + self.assertEqual( + "cola BETWEEN CAST('2022-01-01 01:01:01' AS TIMESTAMP) " "AND CAST('2022-03-01 01:01:01' AS TIMESTAMP)", + F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(), + ) + + def test_over(self): + over_rows = F.sum("cola").over( + Window.partitionBy("colb").orderBy("colc").rowsBetween(1, Window.unboundedFollowing) + ) + self.assertEqual( + "SUM(cola) OVER (PARTITION BY colb ORDER BY colc ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + over_rows.sql(), + ) + over_range = F.sum("cola").over( + Window.partitionBy("colb").orderBy("colc").rangeBetween(1, Window.unboundedFollowing) + ) + self.assertEqual( + "SUM(cola) OVER (PARTITION BY colb ORDER BY colc RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + over_range.sql(), + ) diff --git a/tests/dataframe/unit/test_dataframe.py b/tests/dataframe/unit/test_dataframe.py new file mode 100644 index 0000000..c222cac --- /dev/null +++ b/tests/dataframe/unit/test_dataframe.py @@ -0,0 +1,39 @@ +from sqlglot import expressions as exp +from sqlglot.dataframe.sql.dataframe import DataFrame +from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator + + +class TestDataframe(DataFrameSQLValidator): + def test_hash_select_expression(self): + expression = exp.select("cola").from_("table") + self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression)) + + def test_columns(self): + self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns) + + def test_cache(self): + df = self.df_employee.select("fname").cache() + expected_statements = [ + "DROP VIEW IF EXISTS t11623", + "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", + ] + self.compare_sql(df, expected_statements) + + def test_persist_default(self): + df = self.df_employee.select("fname").persist() + expected_statements = [ + "DROP VIEW IF EXISTS t11623", + "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", + ] + self.compare_sql(df, expected_statements) + + def test_persist_storagelevel(self): + df = self.df_employee.select("fname").persist("DISK_ONLY_2") + expected_statements = [ + "DROP VIEW IF EXISTS t11623", + "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", + ] + self.compare_sql(df, expected_statements) diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py new file mode 100644 index 0000000..14b4a0a --- /dev/null +++ b/tests/dataframe/unit/test_dataframe_writer.py @@ -0,0 +1,86 @@ +from unittest import mock + +import sqlglot +from sqlglot.schema import MappingSchema +from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator + + +class TestDataFrameWriter(DataFrameSQLValidator): + def test_insertInto_full_path(self): + df = self.df_employee.write.insertInto("catalog.db.table_name") + expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_insertInto_db_table(self): + df = self.df_employee.write.insertInto("db.table_name") + expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_insertInto_table(self): + df = self.df_employee.write.insertInto("table_name") + expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_insertInto_overwrite(self): + df = self.df_employee.write.insertInto("table_name", overwrite=True) + expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + @mock.patch("sqlglot.schema", MappingSchema()) + def test_insertInto_byName(self): + sqlglot.schema.add_table("table_name", {"employee_id": "INT"}) + df = self.df_employee.write.byName.insertInto("table_name") + expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_insertInto_cache(self): + df = self.df_employee.cache().write.insertInto("table_name") + expected_statements = [ + "DROP VIEW IF EXISTS t35612", + "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "INSERT INTO table_name SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`", + ] + self.compare_sql(df, expected_statements) + + def test_saveAsTable_format(self): + with self.assertRaises(NotImplementedError): + self.df_employee.write.saveAsTable("table_name", format="parquet").sql(pretty=False)[0] + + def test_saveAsTable_append(self): + df = self.df_employee.write.saveAsTable("table_name", mode="append") + expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_saveAsTable_overwrite(self): + df = self.df_employee.write.saveAsTable("table_name", mode="overwrite") + expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_saveAsTable_error(self): + df = self.df_employee.write.saveAsTable("table_name", mode="error") + expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_saveAsTable_ignore(self): + df = self.df_employee.write.saveAsTable("table_name", mode="ignore") + expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_mode_standalone(self): + df = self.df_employee.write.mode("ignore").saveAsTable("table_name") + expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_mode_override(self): + df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite") + expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + self.compare_sql(df, expected) + + def test_saveAsTable_cache(self): + df = self.df_employee.cache().write.saveAsTable("table_name") + expected_statements = [ + "DROP VIEW IF EXISTS t35612", + "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "CREATE TABLE table_name AS SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`", + ] + self.compare_sql(df, expected_statements) diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py new file mode 100644 index 0000000..10f3b57 --- /dev/null +++ b/tests/dataframe/unit/test_functions.py @@ -0,0 +1,1593 @@ +import datetime +import inspect +import unittest + +from sqlglot import expressions as exp +from sqlglot import parse_one +from sqlglot.dataframe.sql import functions as SF +from sqlglot.errors import ErrorLevel + + +class TestFunctions(unittest.TestCase): + @unittest.skip("not yet fixed.") + def test_invoke_anonymous(self): + for name, func in inspect.getmembers(SF, inspect.isfunction): + with self.subTest(f"{name} should not invoke anonymous_function"): + if "invoke_anonymous_function" in inspect.getsource(func): + func = parse_one(f"{name}()", read="spark", error_level=ErrorLevel.IGNORE) + self.assertIsInstance(func, exp.Anonymous) + + def test_lit(self): + test_str = SF.lit("test") + self.assertEqual("'test'", test_str.sql()) + test_int = SF.lit(30) + self.assertEqual("30", test_int.sql()) + test_float = SF.lit(10.10) + self.assertEqual("10.1", test_float.sql()) + test_bool = SF.lit(False) + self.assertEqual("FALSE", test_bool.sql()) + test_null = SF.lit(None) + self.assertEqual("NULL", test_null.sql()) + test_date = SF.lit(datetime.date(2022, 1, 1)) + self.assertEqual("TO_DATE('2022-01-01')", test_date.sql()) + test_datetime = SF.lit(datetime.datetime(2022, 1, 1, 1, 1, 1)) + self.assertEqual("CAST('2022-01-01 01:01:01' AS TIMESTAMP)", test_datetime.sql()) + test_dict = SF.lit({"cola": 1, "colb": "test"}) + self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql()) + + def test_col(self): + test_col = SF.col("cola") + self.assertEqual("cola", test_col.sql()) + test_col_with_table = SF.col("table.cola") + self.assertEqual("table.cola", test_col_with_table.sql()) + test_col_on_col = SF.col(test_col) + self.assertEqual("cola", test_col_on_col.sql()) + test_int = SF.col(10) + self.assertEqual("10", test_int.sql()) + test_float = SF.col(10.10) + self.assertEqual("10.1", test_float.sql()) + test_bool = SF.col(True) + self.assertEqual("TRUE", test_bool.sql()) + test_array = SF.col([1, 2, "3"]) + self.assertEqual("ARRAY(1, 2, '3')", test_array.sql()) + test_date = SF.col(datetime.date(2022, 1, 1)) + self.assertEqual("TO_DATE('2022-01-01')", test_date.sql()) + test_datetime = SF.col(datetime.datetime(2022, 1, 1, 1, 1, 1)) + self.assertEqual("CAST('2022-01-01 01:01:01' AS TIMESTAMP)", test_datetime.sql()) + test_dict = SF.col({"cola": 1, "colb": "test"}) + self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql()) + + def test_asc(self): + asc_str = SF.asc("cola") + # ASC is removed from output since that is default so we can't check sql + self.assertIsInstance(asc_str.expression, exp.Ordered) + asc_col = SF.asc(SF.col("cola")) + self.assertIsInstance(asc_col.expression, exp.Ordered) + + def test_desc(self): + desc_str = SF.desc("cola") + self.assertEqual("cola DESC", desc_str.sql()) + desc_col = SF.desc(SF.col("cola")) + self.assertEqual("cola DESC", desc_col.sql()) + + def test_sqrt(self): + col_str = SF.sqrt("cola") + self.assertEqual("SQRT(cola)", col_str.sql()) + col = SF.sqrt(SF.col("cola")) + self.assertEqual("SQRT(cola)", col.sql()) + + def test_abs(self): + col_str = SF.abs("cola") + self.assertEqual("ABS(cola)", col_str.sql()) + col = SF.abs(SF.col("cola")) + self.assertEqual("ABS(cola)", col.sql()) + + def test_max(self): + col_str = SF.max("cola") + self.assertEqual("MAX(cola)", col_str.sql()) + col = SF.max(SF.col("cola")) + self.assertEqual("MAX(cola)", col.sql()) + + def test_min(self): + col_str = SF.min("cola") + self.assertEqual("MIN(cola)", col_str.sql()) + col = SF.min(SF.col("cola")) + self.assertEqual("MIN(cola)", col.sql()) + + def test_max_by(self): + col_str = SF.max_by("cola", "colb") + self.assertEqual("MAX_BY(cola, colb)", col_str.sql()) + col = SF.max_by(SF.col("cola"), SF.col("colb")) + self.assertEqual("MAX_BY(cola, colb)", col.sql()) + + def test_min_by(self): + col_str = SF.min_by("cola", "colb") + self.assertEqual("MIN_BY(cola, colb)", col_str.sql()) + col = SF.min_by(SF.col("cola"), SF.col("colb")) + self.assertEqual("MIN_BY(cola, colb)", col.sql()) + + def test_count(self): + col_str = SF.count("cola") + self.assertEqual("COUNT(cola)", col_str.sql()) + col = SF.count(SF.col("cola")) + self.assertEqual("COUNT(cola)", col.sql()) + + def test_sum(self): + col_str = SF.sum("cola") + self.assertEqual("SUM(cola)", col_str.sql()) + col = SF.sum(SF.col("cola")) + self.assertEqual("SUM(cola)", col.sql()) + + def test_avg(self): + col_str = SF.avg("cola") + self.assertEqual("AVG(cola)", col_str.sql()) + col = SF.avg(SF.col("cola")) + self.assertEqual("AVG(cola)", col.sql()) + + def test_mean(self): + col_str = SF.mean("cola") + self.assertEqual("MEAN(cola)", col_str.sql()) + col = SF.mean(SF.col("cola")) + self.assertEqual("MEAN(cola)", col.sql()) + + def test_sum_distinct(self): + with self.assertRaises(NotImplementedError): + SF.sum_distinct("cola") + with self.assertRaises(NotImplementedError): + SF.sumDistinct("cola") + + def test_product(self): + with self.assertRaises(NotImplementedError): + SF.product("cola") + with self.assertRaises(NotImplementedError): + SF.product("cola") + + def test_acos(self): + col_str = SF.acos("cola") + self.assertEqual("ACOS(cola)", col_str.sql()) + col = SF.acos(SF.col("cola")) + self.assertEqual("ACOS(cola)", col.sql()) + + def test_acosh(self): + col_str = SF.acosh("cola") + self.assertEqual("ACOSH(cola)", col_str.sql()) + col = SF.acosh(SF.col("cola")) + self.assertEqual("ACOSH(cola)", col.sql()) + + def test_asin(self): + col_str = SF.asin("cola") + self.assertEqual("ASIN(cola)", col_str.sql()) + col = SF.asin(SF.col("cola")) + self.assertEqual("ASIN(cola)", col.sql()) + + def test_asinh(self): + col_str = SF.asinh("cola") + self.assertEqual("ASINH(cola)", col_str.sql()) + col = SF.asinh(SF.col("cola")) + self.assertEqual("ASINH(cola)", col.sql()) + + def test_atan(self): + col_str = SF.atan("cola") + self.assertEqual("ATAN(cola)", col_str.sql()) + col = SF.atan(SF.col("cola")) + self.assertEqual("ATAN(cola)", col.sql()) + + def test_atan2(self): + col_str = SF.atan2("cola", "colb") + self.assertEqual("ATAN2(cola, colb)", col_str.sql()) + col = SF.atan2(SF.col("cola"), SF.col("colb")) + self.assertEqual("ATAN2(cola, colb)", col.sql()) + col_float = SF.atan2(10.10, "colb") + self.assertEqual("ATAN2(10.1, colb)", col_float.sql()) + col_float2 = SF.atan2("cola", 10.10) + self.assertEqual("ATAN2(cola, 10.1)", col_float2.sql()) + + def test_atanh(self): + col_str = SF.atanh("cola") + self.assertEqual("ATANH(cola)", col_str.sql()) + col = SF.atanh(SF.col("cola")) + self.assertEqual("ATANH(cola)", col.sql()) + + def test_cbrt(self): + col_str = SF.cbrt("cola") + self.assertEqual("CBRT(cola)", col_str.sql()) + col = SF.cbrt(SF.col("cola")) + self.assertEqual("CBRT(cola)", col.sql()) + + def test_ceil(self): + col_str = SF.ceil("cola") + self.assertEqual("CEIL(cola)", col_str.sql()) + col = SF.ceil(SF.col("cola")) + self.assertEqual("CEIL(cola)", col.sql()) + + def test_cos(self): + col_str = SF.cos("cola") + self.assertEqual("COS(cola)", col_str.sql()) + col = SF.cos(SF.col("cola")) + self.assertEqual("COS(cola)", col.sql()) + + def test_cosh(self): + col_str = SF.cosh("cola") + self.assertEqual("COSH(cola)", col_str.sql()) + col = SF.cosh(SF.col("cola")) + self.assertEqual("COSH(cola)", col.sql()) + + def test_cot(self): + col_str = SF.cot("cola") + self.assertEqual("COT(cola)", col_str.sql()) + col = SF.cot(SF.col("cola")) + self.assertEqual("COT(cola)", col.sql()) + + def test_csc(self): + col_str = SF.csc("cola") + self.assertEqual("CSC(cola)", col_str.sql()) + col = SF.csc(SF.col("cola")) + self.assertEqual("CSC(cola)", col.sql()) + + def test_exp(self): + col_str = SF.exp("cola") + self.assertEqual("EXP(cola)", col_str.sql()) + col = SF.exp(SF.col("cola")) + self.assertEqual("EXP(cola)", col.sql()) + + def test_expm1(self): + col_str = SF.expm1("cola") + self.assertEqual("EXPM1(cola)", col_str.sql()) + col = SF.expm1(SF.col("cola")) + self.assertEqual("EXPM1(cola)", col.sql()) + + def test_floor(self): + col_str = SF.floor("cola") + self.assertEqual("FLOOR(cola)", col_str.sql()) + col = SF.floor(SF.col("cola")) + self.assertEqual("FLOOR(cola)", col.sql()) + + def test_log(self): + col_str = SF.log("cola") + self.assertEqual("LN(cola)", col_str.sql()) + col = SF.log(SF.col("cola")) + self.assertEqual("LN(cola)", col.sql()) + col_arg = SF.log(10.0, "age") + self.assertEqual("LOG(10.0, age)", col_arg.sql()) + + def test_log10(self): + col_str = SF.log10("cola") + self.assertEqual("LOG10(cola)", col_str.sql()) + col = SF.log10(SF.col("cola")) + self.assertEqual("LOG10(cola)", col.sql()) + + def test_log1p(self): + col_str = SF.log1p("cola") + self.assertEqual("LOG1P(cola)", col_str.sql()) + col = SF.log1p(SF.col("cola")) + self.assertEqual("LOG1P(cola)", col.sql()) + + def test_log2(self): + col_str = SF.log2("cola") + self.assertEqual("LOG2(cola)", col_str.sql()) + col = SF.log2(SF.col("cola")) + self.assertEqual("LOG2(cola)", col.sql()) + + def test_rint(self): + col_str = SF.rint("cola") + self.assertEqual("RINT(cola)", col_str.sql()) + col = SF.rint(SF.col("cola")) + self.assertEqual("RINT(cola)", col.sql()) + + def test_sec(self): + col_str = SF.sec("cola") + self.assertEqual("SEC(cola)", col_str.sql()) + col = SF.sec(SF.col("cola")) + self.assertEqual("SEC(cola)", col.sql()) + + def test_signum(self): + col_str = SF.signum("cola") + self.assertEqual("SIGNUM(cola)", col_str.sql()) + col = SF.signum(SF.col("cola")) + self.assertEqual("SIGNUM(cola)", col.sql()) + + def test_sin(self): + col_str = SF.sin("cola") + self.assertEqual("SIN(cola)", col_str.sql()) + col = SF.sin(SF.col("cola")) + self.assertEqual("SIN(cola)", col.sql()) + + def test_sinh(self): + col_str = SF.sinh("cola") + self.assertEqual("SINH(cola)", col_str.sql()) + col = SF.sinh(SF.col("cola")) + self.assertEqual("SINH(cola)", col.sql()) + + def test_tan(self): + col_str = SF.tan("cola") + self.assertEqual("TAN(cola)", col_str.sql()) + col = SF.tan(SF.col("cola")) + self.assertEqual("TAN(cola)", col.sql()) + + def test_tanh(self): + col_str = SF.tanh("cola") + self.assertEqual("TANH(cola)", col_str.sql()) + col = SF.tanh(SF.col("cola")) + self.assertEqual("TANH(cola)", col.sql()) + + def test_degrees(self): + col_str = SF.degrees("cola") + self.assertEqual("DEGREES(cola)", col_str.sql()) + col = SF.degrees(SF.col("cola")) + self.assertEqual("DEGREES(cola)", col.sql()) + col_legacy = SF.toDegrees(SF.col("cola")) + self.assertEqual("DEGREES(cola)", col_legacy.sql()) + + def test_radians(self): + col_str = SF.radians("cola") + self.assertEqual("RADIANS(cola)", col_str.sql()) + col = SF.radians(SF.col("cola")) + self.assertEqual("RADIANS(cola)", col.sql()) + col_legacy = SF.toRadians(SF.col("cola")) + self.assertEqual("RADIANS(cola)", col_legacy.sql()) + + def test_bitwise_not(self): + col_str = SF.bitwise_not("cola") + self.assertEqual("~cola", col_str.sql()) + col = SF.bitwise_not(SF.col("cola")) + self.assertEqual("~cola", col.sql()) + col_legacy = SF.bitwiseNOT(SF.col("cola")) + self.assertEqual("~cola", col_legacy.sql()) + + def test_asc_nulls_first(self): + col_str = SF.asc_nulls_first("cola") + self.assertIsInstance(col_str.expression, exp.Ordered) + self.assertEqual("cola", col_str.sql()) + col = SF.asc_nulls_first(SF.col("cola")) + self.assertIsInstance(col.expression, exp.Ordered) + self.assertEqual("cola", col.sql()) + + def test_asc_nulls_last(self): + col_str = SF.asc_nulls_last("cola") + self.assertIsInstance(col_str.expression, exp.Ordered) + self.assertEqual("cola NULLS LAST", col_str.sql()) + col = SF.asc_nulls_last(SF.col("cola")) + self.assertIsInstance(col.expression, exp.Ordered) + self.assertEqual("cola NULLS LAST", col.sql()) + + def test_desc_nulls_first(self): + col_str = SF.desc_nulls_first("cola") + self.assertIsInstance(col_str.expression, exp.Ordered) + self.assertEqual("cola DESC NULLS FIRST", col_str.sql()) + col = SF.desc_nulls_first(SF.col("cola")) + self.assertIsInstance(col.expression, exp.Ordered) + self.assertEqual("cola DESC NULLS FIRST", col.sql()) + + def test_desc_nulls_last(self): + col_str = SF.desc_nulls_last("cola") + self.assertIsInstance(col_str.expression, exp.Ordered) + self.assertEqual("cola DESC", col_str.sql()) + col = SF.desc_nulls_last(SF.col("cola")) + self.assertIsInstance(col.expression, exp.Ordered) + self.assertEqual("cola DESC", col.sql()) + + def test_stddev(self): + col_str = SF.stddev("cola") + self.assertEqual("STDDEV(cola)", col_str.sql()) + col = SF.stddev(SF.col("cola")) + self.assertEqual("STDDEV(cola)", col.sql()) + + def test_stddev_samp(self): + col_str = SF.stddev_samp("cola") + self.assertEqual("STDDEV_SAMP(cola)", col_str.sql()) + col = SF.stddev_samp(SF.col("cola")) + self.assertEqual("STDDEV_SAMP(cola)", col.sql()) + + def test_stddev_pop(self): + col_str = SF.stddev_pop("cola") + self.assertEqual("STDDEV_POP(cola)", col_str.sql()) + col = SF.stddev_pop(SF.col("cola")) + self.assertEqual("STDDEV_POP(cola)", col.sql()) + + def test_variance(self): + col_str = SF.variance("cola") + self.assertEqual("VARIANCE(cola)", col_str.sql()) + col = SF.variance(SF.col("cola")) + self.assertEqual("VARIANCE(cola)", col.sql()) + + def test_var_samp(self): + col_str = SF.var_samp("cola") + self.assertEqual("VARIANCE(cola)", col_str.sql()) + col = SF.var_samp(SF.col("cola")) + self.assertEqual("VARIANCE(cola)", col.sql()) + + def test_var_pop(self): + col_str = SF.var_pop("cola") + self.assertEqual("VAR_POP(cola)", col_str.sql()) + col = SF.var_pop(SF.col("cola")) + self.assertEqual("VAR_POP(cola)", col.sql()) + + def test_skewness(self): + col_str = SF.skewness("cola") + self.assertEqual("SKEWNESS(cola)", col_str.sql()) + col = SF.skewness(SF.col("cola")) + self.assertEqual("SKEWNESS(cola)", col.sql()) + + def test_kurtosis(self): + col_str = SF.kurtosis("cola") + self.assertEqual("KURTOSIS(cola)", col_str.sql()) + col = SF.kurtosis(SF.col("cola")) + self.assertEqual("KURTOSIS(cola)", col.sql()) + + def test_collect_list(self): + col_str = SF.collect_list("cola") + self.assertEqual("COLLECT_LIST(cola)", col_str.sql()) + col = SF.collect_list(SF.col("cola")) + self.assertEqual("COLLECT_LIST(cola)", col.sql()) + + def test_collect_set(self): + col_str = SF.collect_set("cola") + self.assertEqual("COLLECT_SET(cola)", col_str.sql()) + col = SF.collect_set(SF.col("cola")) + self.assertEqual("COLLECT_SET(cola)", col.sql()) + + def test_hypot(self): + col_str = SF.hypot("cola", "colb") + self.assertEqual("HYPOT(cola, colb)", col_str.sql()) + col = SF.hypot(SF.col("cola"), SF.col("colb")) + self.assertEqual("HYPOT(cola, colb)", col.sql()) + col_float = SF.hypot(10.10, "colb") + self.assertEqual("HYPOT(10.1, colb)", col_float.sql()) + col_float2 = SF.hypot("cola", 10.10) + self.assertEqual("HYPOT(cola, 10.1)", col_float2.sql()) + + def test_pow(self): + col_str = SF.pow("cola", "colb") + self.assertEqual("POW(cola, colb)", col_str.sql()) + col = SF.pow(SF.col("cola"), SF.col("colb")) + self.assertEqual("POW(cola, colb)", col.sql()) + col_float = SF.pow(10.10, "colb") + self.assertEqual("POW(10.1, colb)", col_float.sql()) + col_float2 = SF.pow("cola", 10.10) + self.assertEqual("POW(cola, 10.1)", col_float2.sql()) + + def test_row_number(self): + col_str = SF.row_number() + self.assertEqual("ROW_NUMBER()", col_str.sql()) + col = SF.row_number() + self.assertEqual("ROW_NUMBER()", col.sql()) + + def test_dense_rank(self): + col_str = SF.dense_rank() + self.assertEqual("DENSE_RANK()", col_str.sql()) + col = SF.dense_rank() + self.assertEqual("DENSE_RANK()", col.sql()) + + def test_rank(self): + col_str = SF.rank() + self.assertEqual("RANK()", col_str.sql()) + col = SF.rank() + self.assertEqual("RANK()", col.sql()) + + def test_cume_dist(self): + col_str = SF.cume_dist() + self.assertEqual("CUME_DIST()", col_str.sql()) + col = SF.cume_dist() + self.assertEqual("CUME_DIST()", col.sql()) + + def test_percent_rank(self): + col_str = SF.percent_rank() + self.assertEqual("PERCENT_RANK()", col_str.sql()) + col = SF.percent_rank() + self.assertEqual("PERCENT_RANK()", col.sql()) + + def test_approx_count_distinct(self): + col_str = SF.approx_count_distinct("cola") + self.assertEqual("APPROX_COUNT_DISTINCT(cola)", col_str.sql()) + col_str_with_accuracy = SF.approx_count_distinct("cola", 0.05) + self.assertEqual("APPROX_COUNT_DISTINCT(cola, 0.05)", col_str_with_accuracy.sql()) + col = SF.approx_count_distinct(SF.col("cola")) + self.assertEqual("APPROX_COUNT_DISTINCT(cola)", col.sql()) + col_with_accuracy = SF.approx_count_distinct(SF.col("cola"), 0.05) + self.assertEqual("APPROX_COUNT_DISTINCT(cola, 0.05)", col_with_accuracy.sql()) + col_legacy = SF.approxCountDistinct(SF.col("cola")) + self.assertEqual("APPROX_COUNT_DISTINCT(cola)", col_legacy.sql()) + + def test_coalesce(self): + col_str = SF.coalesce("cola", "colb", "colc") + self.assertEqual("COALESCE(cola, colb, colc)", col_str.sql()) + col = SF.coalesce(SF.col("cola"), "colb", SF.col("colc")) + self.assertEqual("COALESCE(cola, colb, colc)", col.sql()) + + def test_corr(self): + col_str = SF.corr("cola", "colb") + self.assertEqual("CORR(cola, colb)", col_str.sql()) + col = SF.corr(SF.col("cola"), "colb") + self.assertEqual("CORR(cola, colb)", col.sql()) + + def test_covar_pop(self): + col_str = SF.covar_pop("cola", "colb") + self.assertEqual("COVAR_POP(cola, colb)", col_str.sql()) + col = SF.covar_pop(SF.col("cola"), "colb") + self.assertEqual("COVAR_POP(cola, colb)", col.sql()) + + def test_covar_samp(self): + col_str = SF.covar_samp("cola", "colb") + self.assertEqual("COVAR_SAMP(cola, colb)", col_str.sql()) + col = SF.covar_samp(SF.col("cola"), "colb") + self.assertEqual("COVAR_SAMP(cola, colb)", col.sql()) + + def test_count_distinct(self): + col_str = SF.count_distinct("cola") + self.assertEqual("COUNT(DISTINCT cola)", col_str.sql()) + col = SF.count_distinct(SF.col("cola")) + self.assertEqual("COUNT(DISTINCT cola)", col.sql()) + col_legacy = SF.countDistinct(SF.col("cola")) + self.assertEqual("COUNT(DISTINCT cola)", col_legacy.sql()) + col_multiple = SF.count_distinct(SF.col("cola"), SF.col("colb")) + self.assertEqual("COUNT(DISTINCT cola, colb)", col_multiple.sql()) + + def test_first(self): + col_str = SF.first("cola") + self.assertEqual("FIRST(cola)", col_str.sql()) + col = SF.first(SF.col("cola")) + self.assertEqual("FIRST(cola)", col.sql()) + ignore_nulls = SF.first("cola", True) + self.assertEqual("FIRST(cola, TRUE)", ignore_nulls.sql()) + + def test_grouping_id(self): + col_str = SF.grouping_id("cola", "colb") + self.assertEqual("GROUPING_ID(cola, colb)", col_str.sql()) + col = SF.grouping_id(SF.col("cola"), SF.col("colb")) + self.assertEqual("GROUPING_ID(cola, colb)", col.sql()) + col_grouping_no_arg = SF.grouping_id() + self.assertEqual("GROUPING_ID()", col_grouping_no_arg.sql()) + col_grouping_single_arg = SF.grouping_id("cola") + self.assertEqual("GROUPING_ID(cola)", col_grouping_single_arg.sql()) + + def test_input_file_name(self): + col = SF.input_file_name() + self.assertEqual("INPUT_FILE_NAME()", col.sql()) + + def test_isnan(self): + col_str = SF.isnan("cola") + self.assertEqual("ISNAN(cola)", col_str.sql()) + col = SF.isnan(SF.col("cola")) + self.assertEqual("ISNAN(cola)", col.sql()) + + def test_isnull(self): + col_str = SF.isnull("cola") + self.assertEqual("ISNULL(cola)", col_str.sql()) + col = SF.isnull(SF.col("cola")) + self.assertEqual("ISNULL(cola)", col.sql()) + + def test_last(self): + col_str = SF.last("cola") + self.assertEqual("LAST(cola)", col_str.sql()) + col = SF.last(SF.col("cola")) + self.assertEqual("LAST(cola)", col.sql()) + ignore_nulls = SF.last("cola", True) + self.assertEqual("LAST(cola, TRUE)", ignore_nulls.sql()) + + def test_monotonically_increasing_id(self): + col = SF.monotonically_increasing_id() + self.assertEqual("MONOTONICALLY_INCREASING_ID()", col.sql()) + + def test_nanvl(self): + col_str = SF.nanvl("cola", "colb") + self.assertEqual("NANVL(cola, colb)", col_str.sql()) + col = SF.nanvl(SF.col("cola"), SF.col("colb")) + self.assertEqual("NANVL(cola, colb)", col.sql()) + + def test_percentile_approx(self): + col_str = SF.percentile_approx("cola", [0.5, 0.4, 0.1]) + self.assertEqual("PERCENTILE_APPROX(cola, ARRAY(0.5, 0.4, 0.1))", col_str.sql()) + col = SF.percentile_approx(SF.col("cola"), [0.5, 0.4, 0.1]) + self.assertEqual("PERCENTILE_APPROX(cola, ARRAY(0.5, 0.4, 0.1))", col.sql()) + col_accuracy = SF.percentile_approx("cola", 0.1, 100) + self.assertEqual("PERCENTILE_APPROX(cola, 0.1, 100)", col_accuracy.sql()) + + def test_rand(self): + col_str = SF.rand(SF.lit(0)) + self.assertEqual("RAND(0)", col_str.sql()) + col = SF.rand(SF.lit(0)) + self.assertEqual("RAND(0)", col.sql()) + no_col = SF.rand() + self.assertEqual("RAND()", no_col.sql()) + + def test_randn(self): + col_str = SF.randn(0) + self.assertEqual("RANDN(0)", col_str.sql()) + col = SF.randn(0) + self.assertEqual("RANDN(0)", col.sql()) + no_col = SF.randn() + self.assertEqual("RANDN()", no_col.sql()) + + def test_round(self): + col_str = SF.round("cola", 0) + self.assertEqual("ROUND(cola, 0)", col_str.sql()) + col = SF.round(SF.col("cola"), 0) + self.assertEqual("ROUND(cola, 0)", col.sql()) + col_no_scale = SF.round("cola") + self.assertEqual("ROUND(cola)", col_no_scale.sql()) + + def test_bround(self): + col_str = SF.bround("cola", 0) + self.assertEqual("BROUND(cola, 0)", col_str.sql()) + col = SF.bround(SF.col("cola"), 0) + self.assertEqual("BROUND(cola, 0)", col.sql()) + col_no_scale = SF.bround("cola") + self.assertEqual("BROUND(cola)", col_no_scale.sql()) + + def test_shiftleft(self): + col_str = SF.shiftleft("cola", 1) + self.assertEqual("SHIFTLEFT(cola, 1)", col_str.sql()) + col = SF.shiftleft(SF.col("cola"), 1) + self.assertEqual("SHIFTLEFT(cola, 1)", col.sql()) + col_legacy = SF.shiftLeft(SF.col("cola"), 1) + self.assertEqual("SHIFTLEFT(cola, 1)", col_legacy.sql()) + + def test_shiftright(self): + col_str = SF.shiftright("cola", 1) + self.assertEqual("SHIFTRIGHT(cola, 1)", col_str.sql()) + col = SF.shiftright(SF.col("cola"), 1) + self.assertEqual("SHIFTRIGHT(cola, 1)", col.sql()) + col_legacy = SF.shiftRight(SF.col("cola"), 1) + self.assertEqual("SHIFTRIGHT(cola, 1)", col_legacy.sql()) + + def test_shiftrightunsigned(self): + col_str = SF.shiftrightunsigned("cola", 1) + self.assertEqual("SHIFTRIGHTUNSIGNED(cola, 1)", col_str.sql()) + col = SF.shiftrightunsigned(SF.col("cola"), 1) + self.assertEqual("SHIFTRIGHTUNSIGNED(cola, 1)", col.sql()) + col_legacy = SF.shiftRightUnsigned(SF.col("cola"), 1) + self.assertEqual("SHIFTRIGHTUNSIGNED(cola, 1)", col_legacy.sql()) + + def test_expr(self): + col_str = SF.expr("LENGTH(name)") + self.assertEqual("LENGTH(name)", col_str.sql()) + + def test_struct(self): + col_str = SF.struct("cola", "colb", "colc") + self.assertEqual("STRUCT(cola, colb, colc)", col_str.sql()) + col = SF.struct(SF.col("cola"), SF.col("colb"), SF.col("colc")) + self.assertEqual("STRUCT(cola, colb, colc)", col.sql()) + col_single = SF.struct("cola") + self.assertEqual("STRUCT(cola)", col_single.sql()) + col_list = SF.struct(["cola", "colb", "colc"]) + self.assertEqual("STRUCT(cola, colb, colc)", col_list.sql()) + + def test_greatest(self): + single_str = SF.greatest("cola") + self.assertEqual("GREATEST(cola)", single_str.sql()) + single_col = SF.greatest(SF.col("cola")) + self.assertEqual("GREATEST(cola)", single_col.sql()) + multiple_mix = SF.greatest("col1", "col2", SF.col("col3"), SF.col("col4")) + self.assertEqual("GREATEST(col1, col2, col3, col4)", multiple_mix.sql()) + + def test_least(self): + single_str = SF.least("cola") + self.assertEqual("LEAST(cola)", single_str.sql()) + single_col = SF.least(SF.col("cola")) + self.assertEqual("LEAST(cola)", single_col.sql()) + multiple_mix = SF.least("col1", "col2", SF.col("col3"), SF.col("col4")) + self.assertEqual("LEAST(col1, col2, col3, col4)", multiple_mix.sql()) + + def test_when(self): + col_simple = SF.when(SF.col("cola") == 2, 1) + self.assertEqual("CASE WHEN cola = 2 THEN 1 END", col_simple.sql()) + col_complex = SF.when(SF.col("cola") == 2, SF.col("colb") + 2) + self.assertEqual("CASE WHEN cola = 2 THEN colb + 2 END", col_complex.sql()) + + def test_conv(self): + col_str = SF.conv("cola", 2, 16) + self.assertEqual("CONV(cola, 2, 16)", col_str.sql()) + col = SF.conv(SF.col("cola"), 2, 16) + self.assertEqual("CONV(cola, 2, 16)", col.sql()) + + def test_factorial(self): + col_str = SF.factorial("cola") + self.assertEqual("FACTORIAL(cola)", col_str.sql()) + col = SF.factorial(SF.col("cola")) + self.assertEqual("FACTORIAL(cola)", col.sql()) + + def test_lag(self): + col_str = SF.lag("cola", 3, "colc") + self.assertEqual("LAG(cola, 3, colc)", col_str.sql()) + col = SF.lag(SF.col("cola"), 3, "colc") + self.assertEqual("LAG(cola, 3, colc)", col.sql()) + col_no_default = SF.lag("cola", 3) + self.assertEqual("LAG(cola, 3)", col_no_default.sql()) + col_no_offset = SF.lag("cola") + self.assertEqual("LAG(cola)", col_no_offset.sql()) + + def test_lead(self): + col_str = SF.lead("cola", 3, "colc") + self.assertEqual("LEAD(cola, 3, colc)", col_str.sql()) + col = SF.lead(SF.col("cola"), 3, "colc") + self.assertEqual("LEAD(cola, 3, colc)", col.sql()) + col_no_default = SF.lead("cola", 3) + self.assertEqual("LEAD(cola, 3)", col_no_default.sql()) + col_no_offset = SF.lead("cola") + self.assertEqual("LEAD(cola)", col_no_offset.sql()) + + def test_nth_value(self): + col_str = SF.nth_value("cola", 3) + self.assertEqual("NTH_VALUE(cola, 3)", col_str.sql()) + col = SF.nth_value(SF.col("cola"), 3) + self.assertEqual("NTH_VALUE(cola, 3)", col.sql()) + col_no_offset = SF.nth_value("cola") + self.assertEqual("NTH_VALUE(cola)", col_no_offset.sql()) + with self.assertRaises(NotImplementedError): + SF.nth_value("cola", ignoreNulls=True) + + def test_ntile(self): + col = SF.ntile(2) + self.assertEqual("NTILE(2)", col.sql()) + + def test_current_date(self): + col = SF.current_date() + self.assertEqual("CURRENT_DATE", col.sql()) + + def test_current_timestamp(self): + col = SF.current_timestamp() + self.assertEqual("CURRENT_TIMESTAMP()", col.sql()) + + def test_date_format(self): + col_str = SF.date_format("cola", "MM/dd/yyy") + self.assertEqual("DATE_FORMAT(cola, 'MM/dd/yyy')", col_str.sql()) + col = SF.date_format(SF.col("cola"), "MM/dd/yyy") + self.assertEqual("DATE_FORMAT(cola, 'MM/dd/yyy')", col.sql()) + + def test_year(self): + col_str = SF.year("cola") + self.assertEqual("YEAR(cola)", col_str.sql()) + col = SF.year(SF.col("cola")) + self.assertEqual("YEAR(cola)", col.sql()) + + def test_quarter(self): + col_str = SF.quarter("cola") + self.assertEqual("QUARTER(cola)", col_str.sql()) + col = SF.quarter(SF.col("cola")) + self.assertEqual("QUARTER(cola)", col.sql()) + + def test_month(self): + col_str = SF.month("cola") + self.assertEqual("MONTH(cola)", col_str.sql()) + col = SF.month(SF.col("cola")) + self.assertEqual("MONTH(cola)", col.sql()) + + def test_dayofweek(self): + col_str = SF.dayofweek("cola") + self.assertEqual("DAYOFWEEK(cola)", col_str.sql()) + col = SF.dayofweek(SF.col("cola")) + self.assertEqual("DAYOFWEEK(cola)", col.sql()) + + def test_dayofmonth(self): + col_str = SF.dayofmonth("cola") + self.assertEqual("DAYOFMONTH(cola)", col_str.sql()) + col = SF.dayofmonth(SF.col("cola")) + self.assertEqual("DAYOFMONTH(cola)", col.sql()) + + def test_dayofyear(self): + col_str = SF.dayofyear("cola") + self.assertEqual("DAYOFYEAR(cola)", col_str.sql()) + col = SF.dayofyear(SF.col("cola")) + self.assertEqual("DAYOFYEAR(cola)", col.sql()) + + def test_hour(self): + col_str = SF.hour("cola") + self.assertEqual("HOUR(cola)", col_str.sql()) + col = SF.hour(SF.col("cola")) + self.assertEqual("HOUR(cola)", col.sql()) + + def test_minute(self): + col_str = SF.minute("cola") + self.assertEqual("MINUTE(cola)", col_str.sql()) + col = SF.minute(SF.col("cola")) + self.assertEqual("MINUTE(cola)", col.sql()) + + def test_second(self): + col_str = SF.second("cola") + self.assertEqual("SECOND(cola)", col_str.sql()) + col = SF.second(SF.col("cola")) + self.assertEqual("SECOND(cola)", col.sql()) + + def test_weekofyear(self): + col_str = SF.weekofyear("cola") + self.assertEqual("WEEKOFYEAR(cola)", col_str.sql()) + col = SF.weekofyear(SF.col("cola")) + self.assertEqual("WEEKOFYEAR(cola)", col.sql()) + + def test_make_date(self): + col_str = SF.make_date("cola", "colb", "colc") + self.assertEqual("MAKE_DATE(cola, colb, colc)", col_str.sql()) + col = SF.make_date(SF.col("cola"), SF.col("colb"), "colc") + self.assertEqual("MAKE_DATE(cola, colb, colc)", col.sql()) + + def test_date_add(self): + col_str = SF.date_add("cola", 2) + self.assertEqual("DATE_ADD(cola, 2)", col_str.sql()) + col = SF.date_add(SF.col("cola"), 2) + self.assertEqual("DATE_ADD(cola, 2)", col.sql()) + col_col_for_add = SF.date_add("cola", "colb") + self.assertEqual("DATE_ADD(cola, colb)", col_col_for_add.sql()) + + def test_date_sub(self): + col_str = SF.date_sub("cola", 2) + self.assertEqual("DATE_SUB(cola, 2)", col_str.sql()) + col = SF.date_sub(SF.col("cola"), 2) + self.assertEqual("DATE_SUB(cola, 2)", col.sql()) + col_col_for_add = SF.date_sub("cola", "colb") + self.assertEqual("DATE_SUB(cola, colb)", col_col_for_add.sql()) + + def test_date_diff(self): + col_str = SF.date_diff("cola", "colb") + self.assertEqual("DATEDIFF(cola, colb)", col_str.sql()) + col = SF.date_diff(SF.col("cola"), SF.col("colb")) + self.assertEqual("DATEDIFF(cola, colb)", col.sql()) + + def test_add_months(self): + col_str = SF.add_months("cola", 2) + self.assertEqual("ADD_MONTHS(cola, 2)", col_str.sql()) + col = SF.add_months(SF.col("cola"), 2) + self.assertEqual("ADD_MONTHS(cola, 2)", col.sql()) + col_col_for_add = SF.add_months("cola", "colb") + self.assertEqual("ADD_MONTHS(cola, colb)", col_col_for_add.sql()) + + def test_months_between(self): + col_str = SF.months_between("cola", "colb") + self.assertEqual("MONTHS_BETWEEN(cola, colb)", col_str.sql()) + col = SF.months_between(SF.col("cola"), SF.col("colb")) + self.assertEqual("MONTHS_BETWEEN(cola, colb)", col.sql()) + col_round_off = SF.months_between("cola", "colb", True) + self.assertEqual("MONTHS_BETWEEN(cola, colb, TRUE)", col_round_off.sql()) + + def test_to_date(self): + col_str = SF.to_date("cola") + self.assertEqual("TO_DATE(cola)", col_str.sql()) + col = SF.to_date(SF.col("cola")) + self.assertEqual("TO_DATE(cola)", col.sql()) + col_with_format = SF.to_date("cola", "yyyy-MM-dd") + self.assertEqual("TO_DATE(cola, 'yyyy-MM-dd')", col_with_format.sql()) + + def test_to_timestamp(self): + col_str = SF.to_timestamp("cola") + self.assertEqual("TO_TIMESTAMP(cola)", col_str.sql()) + col = SF.to_timestamp(SF.col("cola")) + self.assertEqual("TO_TIMESTAMP(cola)", col.sql()) + col_with_format = SF.to_timestamp("cola", "yyyy-MM-dd") + self.assertEqual("TO_TIMESTAMP(cola, 'yyyy-MM-dd')", col_with_format.sql()) + + def test_trunc(self): + col_str = SF.trunc("cola", "year") + self.assertEqual("TRUNC(cola, 'year')", col_str.sql()) + col = SF.trunc(SF.col("cola"), "year") + self.assertEqual("TRUNC(cola, 'year')", col.sql()) + + def test_date_trunc(self): + col_str = SF.date_trunc("year", "cola") + self.assertEqual("DATE_TRUNC('year', cola)", col_str.sql()) + col = SF.date_trunc("year", SF.col("cola")) + self.assertEqual("DATE_TRUNC('year', cola)", col.sql()) + + def test_next_day(self): + col_str = SF.next_day("cola", "Mon") + self.assertEqual("NEXT_DAY(cola, 'Mon')", col_str.sql()) + col = SF.next_day(SF.col("cola"), "Mon") + self.assertEqual("NEXT_DAY(cola, 'Mon')", col.sql()) + + def test_last_day(self): + col_str = SF.last_day("cola") + self.assertEqual("LAST_DAY(cola)", col_str.sql()) + col = SF.last_day(SF.col("cola")) + self.assertEqual("LAST_DAY(cola)", col.sql()) + + def test_from_unixtime(self): + col_str = SF.from_unixtime("cola") + self.assertEqual("FROM_UNIXTIME(cola)", col_str.sql()) + col = SF.from_unixtime(SF.col("cola")) + self.assertEqual("FROM_UNIXTIME(cola)", col.sql()) + col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm:ss") + self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql()) + + def test_unix_timestamp(self): + col_str = SF.unix_timestamp("cola") + self.assertEqual("UNIX_TIMESTAMP(cola)", col_str.sql()) + col = SF.unix_timestamp(SF.col("cola")) + self.assertEqual("UNIX_TIMESTAMP(cola)", col.sql()) + col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm:ss") + self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql()) + col_current = SF.unix_timestamp() + self.assertEqual("UNIX_TIMESTAMP()", col_current.sql()) + + def test_from_utc_timestamp(self): + col_str = SF.from_utc_timestamp("cola", "PST") + self.assertEqual("FROM_UTC_TIMESTAMP(cola, 'PST')", col_str.sql()) + col = SF.from_utc_timestamp(SF.col("cola"), "PST") + self.assertEqual("FROM_UTC_TIMESTAMP(cola, 'PST')", col.sql()) + col_col = SF.from_utc_timestamp("cola", SF.col("colb")) + self.assertEqual("FROM_UTC_TIMESTAMP(cola, colb)", col_col.sql()) + + def test_to_utc_timestamp(self): + col_str = SF.to_utc_timestamp("cola", "PST") + self.assertEqual("TO_UTC_TIMESTAMP(cola, 'PST')", col_str.sql()) + col = SF.to_utc_timestamp(SF.col("cola"), "PST") + self.assertEqual("TO_UTC_TIMESTAMP(cola, 'PST')", col.sql()) + col_col = SF.to_utc_timestamp("cola", SF.col("colb")) + self.assertEqual("TO_UTC_TIMESTAMP(cola, colb)", col_col.sql()) + + def test_timestamp_seconds(self): + col_str = SF.timestamp_seconds("cola") + self.assertEqual("TIMESTAMP_SECONDS(cola)", col_str.sql()) + col = SF.timestamp_seconds(SF.col("cola")) + self.assertEqual("TIMESTAMP_SECONDS(cola)", col.sql()) + + def test_window(self): + col_str = SF.window("cola", "10 minutes") + self.assertEqual("WINDOW(cola, '10 minutes')", col_str.sql()) + col = SF.window(SF.col("cola"), "10 minutes") + self.assertEqual("WINDOW(cola, '10 minutes')", col.sql()) + col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds") + self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()) + col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds") + self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()) + col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds") + self.assertEqual( + "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", col_no_slide.sql() + ) + + def test_session_window(self): + col_str = SF.session_window("cola", "5 seconds") + self.assertEqual("SESSION_WINDOW(cola, '5 seconds')", col_str.sql()) + col = SF.session_window(SF.col("cola"), SF.lit("5 seconds")) + self.assertEqual("SESSION_WINDOW(cola, '5 seconds')", col.sql()) + + def test_crc32(self): + col_str = SF.crc32("Spark") + self.assertEqual("CRC32('Spark')", col_str.sql()) + col = SF.crc32(SF.col("cola")) + self.assertEqual("CRC32(cola)", col.sql()) + + def test_md5(self): + col_str = SF.md5("Spark") + self.assertEqual("MD5('Spark')", col_str.sql()) + col = SF.md5(SF.col("cola")) + self.assertEqual("MD5(cola)", col.sql()) + + def test_sha1(self): + col_str = SF.sha1("Spark") + self.assertEqual("SHA1('Spark')", col_str.sql()) + col = SF.sha1(SF.col("cola")) + self.assertEqual("SHA1(cola)", col.sql()) + + def test_sha2(self): + col_str = SF.sha2("Spark", 256) + self.assertEqual("SHA2('Spark', 256)", col_str.sql()) + col = SF.sha2(SF.col("cola"), 256) + self.assertEqual("SHA2(cola, 256)", col.sql()) + + def test_hash(self): + col_str = SF.hash("cola", "colb", "colc") + self.assertEqual("HASH(cola, colb, colc)", col_str.sql()) + col = SF.hash(SF.col("cola"), SF.col("colb"), SF.col("colc")) + self.assertEqual("HASH(cola, colb, colc)", col.sql()) + + def test_xxhash64(self): + col_str = SF.xxhash64("cola", "colb", "colc") + self.assertEqual("XXHASH64(cola, colb, colc)", col_str.sql()) + col = SF.xxhash64(SF.col("cola"), SF.col("colb"), SF.col("colc")) + self.assertEqual("XXHASH64(cola, colb, colc)", col.sql()) + + def test_assert_true(self): + col = SF.assert_true(SF.col("cola") < SF.col("colb")) + self.assertEqual("ASSERT_TRUE(cola < colb)", col.sql()) + col_error_msg_col = SF.assert_true(SF.col("cola") < SF.col("colb"), SF.col("colc")) + self.assertEqual("ASSERT_TRUE(cola < colb, colc)", col_error_msg_col.sql()) + col_error_msg_lit = SF.assert_true(SF.col("cola") < SF.col("colb"), "error") + self.assertEqual("ASSERT_TRUE(cola < colb, 'error')", col_error_msg_lit.sql()) + + def test_raise_error(self): + col_str = SF.raise_error("custom error message") + self.assertEqual("RAISE_ERROR('custom error message')", col_str.sql()) + col = SF.raise_error(SF.col("cola")) + self.assertEqual("RAISE_ERROR(cola)", col.sql()) + + def test_upper(self): + col_str = SF.upper("cola") + self.assertEqual("UPPER(cola)", col_str.sql()) + col = SF.upper(SF.col("cola")) + self.assertEqual("UPPER(cola)", col.sql()) + + def test_lower(self): + col_str = SF.lower("cola") + self.assertEqual("LOWER(cola)", col_str.sql()) + col = SF.lower(SF.col("cola")) + self.assertEqual("LOWER(cola)", col.sql()) + + def test_ascii(self): + col_str = SF.ascii(SF.lit(2)) + self.assertEqual("ASCII(2)", col_str.sql()) + col = SF.ascii(SF.col("cola")) + self.assertEqual("ASCII(cola)", col.sql()) + + def test_base64(self): + col_str = SF.base64(SF.lit(2)) + self.assertEqual("BASE64(2)", col_str.sql()) + col = SF.base64(SF.col("cola")) + self.assertEqual("BASE64(cola)", col.sql()) + + def test_unbase64(self): + col_str = SF.unbase64(SF.lit(2)) + self.assertEqual("UNBASE64(2)", col_str.sql()) + col = SF.unbase64(SF.col("cola")) + self.assertEqual("UNBASE64(cola)", col.sql()) + + def test_ltrim(self): + col_str = SF.ltrim(SF.lit("Spark")) + self.assertEqual("LTRIM('Spark')", col_str.sql()) + col = SF.ltrim(SF.col("cola")) + self.assertEqual("LTRIM(cola)", col.sql()) + + def test_rtrim(self): + col_str = SF.rtrim(SF.lit("Spark")) + self.assertEqual("RTRIM('Spark')", col_str.sql()) + col = SF.rtrim(SF.col("cola")) + self.assertEqual("RTRIM(cola)", col.sql()) + + def test_trim(self): + col_str = SF.trim(SF.lit("Spark")) + self.assertEqual("TRIM('Spark')", col_str.sql()) + col = SF.trim(SF.col("cola")) + self.assertEqual("TRIM(cola)", col.sql()) + + def test_concat_ws(self): + col_str = SF.concat_ws("-", "cola", "colb") + self.assertEqual("CONCAT_WS('-', cola, colb)", col_str.sql()) + col = SF.concat_ws("-", SF.col("cola"), SF.col("colb")) + self.assertEqual("CONCAT_WS('-', cola, colb)", col.sql()) + + def test_decode(self): + col_str = SF.decode("cola", "US-ASCII") + self.assertEqual("DECODE(cola, 'US-ASCII')", col_str.sql()) + col = SF.decode(SF.col("cola"), "US-ASCII") + self.assertEqual("DECODE(cola, 'US-ASCII')", col.sql()) + + def test_encode(self): + col_str = SF.encode("cola", "US-ASCII") + self.assertEqual("ENCODE(cola, 'US-ASCII')", col_str.sql()) + col = SF.encode(SF.col("cola"), "US-ASCII") + self.assertEqual("ENCODE(cola, 'US-ASCII')", col.sql()) + + def test_format_number(self): + col_str = SF.format_number("cola", 4) + self.assertEqual("FORMAT_NUMBER(cola, 4)", col_str.sql()) + col = SF.format_number(SF.col("cola"), 4) + self.assertEqual("FORMAT_NUMBER(cola, 4)", col.sql()) + + def test_format_string(self): + col_str = SF.format_string("%d %s", "cola", "colb", "colc") + self.assertEqual("FORMAT_STRING('%d %s', cola, colb, colc)", col_str.sql()) + col = SF.format_string("%d %s", SF.col("cola"), SF.col("colb"), SF.col("colc")) + self.assertEqual("FORMAT_STRING('%d %s', cola, colb, colc)", col.sql()) + + def test_instr(self): + col_str = SF.instr("cola", "test") + self.assertEqual("INSTR(cola, 'test')", col_str.sql()) + col = SF.instr(SF.col("cola"), "test") + self.assertEqual("INSTR(cola, 'test')", col.sql()) + + def test_overlay(self): + col_str = SF.overlay("cola", "colb", 3, 7) + self.assertEqual("OVERLAY(cola, colb, 3, 7)", col_str.sql()) + col = SF.overlay(SF.col("cola"), SF.col("colb"), SF.lit(3), SF.lit(7)) + self.assertEqual("OVERLAY(cola, colb, 3, 7)", col.sql()) + col_no_length = SF.overlay("cola", "colb", 3) + self.assertEqual("OVERLAY(cola, colb, 3)", col_no_length.sql()) + + def test_sentences(self): + col_str = SF.sentences("cola", SF.lit("en"), SF.lit("US")) + self.assertEqual("SENTENCES(cola, 'en', 'US')", col_str.sql()) + col = SF.sentences(SF.col("cola"), SF.lit("en"), SF.lit("US")) + self.assertEqual("SENTENCES(cola, 'en', 'US')", col.sql()) + col_no_country = SF.sentences("cola", SF.lit("en")) + self.assertEqual("SENTENCES(cola, 'en')", col_no_country.sql()) + col_no_lang = SF.sentences(SF.col("cola"), country=SF.lit("US")) + self.assertEqual("SENTENCES(cola, 'en', 'US')", col_no_lang.sql()) + col_defaults = SF.sentences(SF.col("cola")) + self.assertEqual("SENTENCES(cola)", col_defaults.sql()) + + def test_substring(self): + col_str = SF.substring("cola", 2, 3) + self.assertEqual("SUBSTRING(cola, 2, 3)", col_str.sql()) + col = SF.substring(SF.col("cola"), 2, 3) + self.assertEqual("SUBSTRING(cola, 2, 3)", col.sql()) + + def test_substring_index(self): + col_str = SF.substring_index("cola", ".", 2) + self.assertEqual("SUBSTRING_INDEX(cola, '.', 2)", col_str.sql()) + col = SF.substring_index(SF.col("cola"), ".", 2) + self.assertEqual("SUBSTRING_INDEX(cola, '.', 2)", col.sql()) + + def test_levenshtein(self): + col_str = SF.levenshtein("cola", "colb") + self.assertEqual("LEVENSHTEIN(cola, colb)", col_str.sql()) + col = SF.levenshtein(SF.col("cola"), SF.col("colb")) + self.assertEqual("LEVENSHTEIN(cola, colb)", col.sql()) + + def test_locate(self): + col_str = SF.locate("test", "cola", 3) + self.assertEqual("LOCATE('test', cola, 3)", col_str.sql()) + col = SF.locate("test", SF.col("cola"), 3) + self.assertEqual("LOCATE('test', cola, 3)", col.sql()) + col_no_pos = SF.locate("test", "cola") + self.assertEqual("LOCATE('test', cola)", col_no_pos.sql()) + + def test_lpad(self): + col_str = SF.lpad("cola", 3, "#") + self.assertEqual("LPAD(cola, 3, '#')", col_str.sql()) + col = SF.lpad(SF.col("cola"), 3, "#") + self.assertEqual("LPAD(cola, 3, '#')", col.sql()) + + def test_rpad(self): + col_str = SF.rpad("cola", 3, "#") + self.assertEqual("RPAD(cola, 3, '#')", col_str.sql()) + col = SF.rpad(SF.col("cola"), 3, "#") + self.assertEqual("RPAD(cola, 3, '#')", col.sql()) + + def test_repeat(self): + col_str = SF.repeat("cola", 3) + self.assertEqual("REPEAT(cola, 3)", col_str.sql()) + col = SF.repeat(SF.col("cola"), 3) + self.assertEqual("REPEAT(cola, 3)", col.sql()) + + def test_split(self): + col_str = SF.split("cola", "[ABC]", 3) + self.assertEqual("SPLIT(cola, '[ABC]', 3)", col_str.sql()) + col = SF.split(SF.col("cola"), "[ABC]", 3) + self.assertEqual("SPLIT(cola, '[ABC]', 3)", col.sql()) + col_no_limit = SF.split("cola", "[ABC]") + self.assertEqual("SPLIT(cola, '[ABC]')", col_no_limit.sql()) + + def test_regexp_extract(self): + col_str = SF.regexp_extract("cola", r"(\d+)-(\d+)", 1) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col_str.sql()) + col = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)", 1) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col.sql()) + col_no_idx = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)") + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)')", col_no_idx.sql()) + + def test_regexp_replace(self): + col_str = SF.regexp_replace("cola", r"(\d+)", "--") + self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col_str.sql()) + col = SF.regexp_replace(SF.col("cola"), r"(\d+)", "--") + self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col.sql()) + + def test_initcap(self): + col_str = SF.initcap("cola") + self.assertEqual("INITCAP(cola)", col_str.sql()) + col = SF.initcap(SF.col("cola")) + self.assertEqual("INITCAP(cola)", col.sql()) + + def test_soundex(self): + col_str = SF.soundex("cola") + self.assertEqual("SOUNDEX(cola)", col_str.sql()) + col = SF.soundex(SF.col("cola")) + self.assertEqual("SOUNDEX(cola)", col.sql()) + + def test_bin(self): + col_str = SF.bin("cola") + self.assertEqual("BIN(cola)", col_str.sql()) + col = SF.bin(SF.col("cola")) + self.assertEqual("BIN(cola)", col.sql()) + + def test_hex(self): + col_str = SF.hex("cola") + self.assertEqual("HEX(cola)", col_str.sql()) + col = SF.hex(SF.col("cola")) + self.assertEqual("HEX(cola)", col.sql()) + + def test_unhex(self): + col_str = SF.unhex("cola") + self.assertEqual("UNHEX(cola)", col_str.sql()) + col = SF.unhex(SF.col("cola")) + self.assertEqual("UNHEX(cola)", col.sql()) + + def test_length(self): + col_str = SF.length("cola") + self.assertEqual("LENGTH(cola)", col_str.sql()) + col = SF.length(SF.col("cola")) + self.assertEqual("LENGTH(cola)", col.sql()) + + def test_octet_length(self): + col_str = SF.octet_length("cola") + self.assertEqual("OCTET_LENGTH(cola)", col_str.sql()) + col = SF.octet_length(SF.col("cola")) + self.assertEqual("OCTET_LENGTH(cola)", col.sql()) + + def test_bit_length(self): + col_str = SF.bit_length("cola") + self.assertEqual("BIT_LENGTH(cola)", col_str.sql()) + col = SF.bit_length(SF.col("cola")) + self.assertEqual("BIT_LENGTH(cola)", col.sql()) + + def test_translate(self): + col_str = SF.translate("cola", "abc", "xyz") + self.assertEqual("TRANSLATE(cola, 'abc', 'xyz')", col_str.sql()) + col = SF.translate(SF.col("cola"), "abc", "xyz") + self.assertEqual("TRANSLATE(cola, 'abc', 'xyz')", col.sql()) + + def test_array(self): + col_str = SF.array("cola", "colb") + self.assertEqual("ARRAY(cola, colb)", col_str.sql()) + col = SF.array(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAY(cola, colb)", col.sql()) + col_array = SF.array(["cola", "colb"]) + self.assertEqual("ARRAY(cola, colb)", col_array.sql()) + + def test_create_map(self): + col_str = SF.create_map("keya", "valuea", "keyb", "valueb") + self.assertEqual("MAP(keya, valuea, keyb, valueb)", col_str.sql()) + col = SF.create_map(SF.col("keya"), SF.col("valuea"), SF.col("keyb"), SF.col("valueb")) + self.assertEqual("MAP(keya, valuea, keyb, valueb)", col.sql()) + col_array = SF.create_map(["keya", "valuea", "keyb", "valueb"]) + self.assertEqual("MAP(keya, valuea, keyb, valueb)", col_array.sql()) + + def test_map_from_arrays(self): + col_str = SF.map_from_arrays("cola", "colb") + self.assertEqual("MAP_FROM_ARRAYS(cola, colb)", col_str.sql()) + col = SF.map_from_arrays(SF.col("cola"), SF.col("colb")) + self.assertEqual("MAP_FROM_ARRAYS(cola, colb)", col.sql()) + + def test_array_contains(self): + col_str = SF.array_contains("cola", "test") + self.assertEqual("ARRAY_CONTAINS(cola, 'test')", col_str.sql()) + col = SF.array_contains(SF.col("cola"), "test") + self.assertEqual("ARRAY_CONTAINS(cola, 'test')", col.sql()) + col_as_value = SF.array_contains("cola", SF.col("colb")) + self.assertEqual("ARRAY_CONTAINS(cola, colb)", col_as_value.sql()) + + def test_arrays_overlap(self): + col_str = SF.arrays_overlap("cola", "colb") + self.assertEqual("ARRAYS_OVERLAP(cola, colb)", col_str.sql()) + col = SF.arrays_overlap(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAYS_OVERLAP(cola, colb)", col.sql()) + + def test_slice(self): + col_str = SF.slice("cola", SF.col("colb"), SF.col("colc")) + self.assertEqual("SLICE(cola, colb, colc)", col_str.sql()) + col = SF.slice(SF.col("cola"), SF.col("colb"), SF.col("colc")) + self.assertEqual("SLICE(cola, colb, colc)", col.sql()) + col_ints = SF.slice("cola", 1, 10) + self.assertEqual("SLICE(cola, 1, 10)", col_ints.sql()) + + def test_array_join(self): + col_str = SF.array_join("cola", "-", "NULL_REPLACEMENT") + self.assertEqual("ARRAY_JOIN(cola, '-', 'NULL_REPLACEMENT')", col_str.sql()) + col = SF.array_join(SF.col("cola"), "-", "NULL_REPLACEMENT") + self.assertEqual("ARRAY_JOIN(cola, '-', 'NULL_REPLACEMENT')", col.sql()) + col_no_replacement = SF.array_join("cola", "-") + self.assertEqual("ARRAY_JOIN(cola, '-')", col_no_replacement.sql()) + + def test_concat(self): + col_str = SF.concat("cola", "colb") + self.assertEqual("CONCAT(cola, colb)", col_str.sql()) + col = SF.concat(SF.col("cola"), SF.col("colb")) + self.assertEqual("CONCAT(cola, colb)", col.sql()) + col_single = SF.concat("cola") + self.assertEqual("CONCAT(cola)", col_single.sql()) + + def test_array_position(self): + col_str = SF.array_position("cola", SF.col("colb")) + self.assertEqual("ARRAY_POSITION(cola, colb)", col_str.sql()) + col = SF.array_position(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAY_POSITION(cola, colb)", col.sql()) + col_lit = SF.array_position("cola", "test") + self.assertEqual("ARRAY_POSITION(cola, 'test')", col_lit) + + def test_element_at(self): + col_str = SF.element_at("cola", SF.col("colb")) + self.assertEqual("ELEMENT_AT(cola, colb)", col_str.sql()) + col = SF.element_at(SF.col("cola"), SF.col("colb")) + self.assertEqual("ELEMENT_AT(cola, colb)", col.sql()) + col_lit = SF.element_at("cola", "test") + self.assertEqual("ELEMENT_AT(cola, 'test')", col_lit) + + def test_array_remove(self): + col_str = SF.array_remove("cola", SF.col("colb")) + self.assertEqual("ARRAY_REMOVE(cola, colb)", col_str.sql()) + col = SF.array_remove(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAY_REMOVE(cola, colb)", col.sql()) + col_lit = SF.array_remove("cola", "test") + self.assertEqual("ARRAY_REMOVE(cola, 'test')", col_lit) + + def test_array_distinct(self): + col_str = SF.array_distinct("cola") + self.assertEqual("ARRAY_DISTINCT(cola)", col_str.sql()) + col = SF.array_distinct(SF.col("cola")) + self.assertEqual("ARRAY_DISTINCT(cola)", col.sql()) + + def test_array_intersect(self): + col_str = SF.array_intersect("cola", "colb") + self.assertEqual("ARRAY_INTERSECT(cola, colb)", col_str.sql()) + col = SF.array_intersect(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAY_INTERSECT(cola, colb)", col.sql()) + + def test_array_union(self): + col_str = SF.array_union("cola", "colb") + self.assertEqual("ARRAY_UNION(cola, colb)", col_str.sql()) + col = SF.array_union(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAY_UNION(cola, colb)", col.sql()) + + def test_array_except(self): + col_str = SF.array_except("cola", "colb") + self.assertEqual("ARRAY_EXCEPT(cola, colb)", col_str.sql()) + col = SF.array_except(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAY_EXCEPT(cola, colb)", col.sql()) + + def test_explode(self): + col_str = SF.explode("cola") + self.assertEqual("EXPLODE(cola)", col_str.sql()) + col = SF.explode(SF.col("cola")) + self.assertEqual("EXPLODE(cola)", col.sql()) + + def test_pos_explode(self): + col_str = SF.posexplode("cola") + self.assertEqual("POSEXPLODE(cola)", col_str.sql()) + col = SF.posexplode(SF.col("cola")) + self.assertEqual("POSEXPLODE(cola)", col.sql()) + + def test_explode_outer(self): + col_str = SF.explode_outer("cola") + self.assertEqual("EXPLODE_OUTER(cola)", col_str.sql()) + col = SF.explode_outer(SF.col("cola")) + self.assertEqual("EXPLODE_OUTER(cola)", col.sql()) + + def test_posexplode_outer(self): + col_str = SF.posexplode_outer("cola") + self.assertEqual("POSEXPLODE_OUTER(cola)", col_str.sql()) + col = SF.posexplode_outer(SF.col("cola")) + self.assertEqual("POSEXPLODE_OUTER(cola)", col.sql()) + + def test_get_json_object(self): + col_str = SF.get_json_object("cola", "$.f1") + self.assertEqual("GET_JSON_OBJECT(cola, '$.f1')", col_str.sql()) + col = SF.get_json_object(SF.col("cola"), "$.f1") + self.assertEqual("GET_JSON_OBJECT(cola, '$.f1')", col.sql()) + + def test_json_tuple(self): + col_str = SF.json_tuple("cola", "f1", "f2") + self.assertEqual("JSON_TUPLE(cola, 'f1', 'f2')", col_str.sql()) + col = SF.json_tuple(SF.col("cola"), "f1", "f2") + self.assertEqual("JSON_TUPLE(cola, 'f1', 'f2')", col.sql()) + + def test_from_json(self): + col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + col_no_option = SF.from_json("cola", "cola INT") + self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql()) + + def test_to_json(self): + col_str = SF.to_json("cola", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("TO_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + col = SF.to_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("TO_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + col_no_option = SF.to_json("cola") + self.assertEqual("TO_JSON(cola)", col_no_option.sql()) + + def test_schema_of_json(self): + col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + col_no_option = SF.schema_of_json("cola") + self.assertEqual("SCHEMA_OF_JSON(cola)", col_no_option.sql()) + + def test_schema_of_csv(self): + col_str = SF.schema_of_csv("cola", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("SCHEMA_OF_CSV(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + col = SF.schema_of_csv(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("SCHEMA_OF_CSV(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + col_no_option = SF.schema_of_csv("cola") + self.assertEqual("SCHEMA_OF_CSV(cola)", col_no_option.sql()) + + def test_to_csv(self): + col_str = SF.to_csv("cola", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("TO_CSV(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + col = SF.to_csv(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("TO_CSV(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + col_no_option = SF.to_csv("cola") + self.assertEqual("TO_CSV(cola)", col_no_option.sql()) + + def test_size(self): + col_str = SF.size("cola") + self.assertEqual("SIZE(cola)", col_str.sql()) + col = SF.size(SF.col("cola")) + self.assertEqual("SIZE(cola)", col.sql()) + + def test_array_min(self): + col_str = SF.array_min("cola") + self.assertEqual("ARRAY_MIN(cola)", col_str.sql()) + col = SF.array_min(SF.col("cola")) + self.assertEqual("ARRAY_MIN(cola)", col.sql()) + + def test_array_max(self): + col_str = SF.array_max("cola") + self.assertEqual("ARRAY_MAX(cola)", col_str.sql()) + col = SF.array_max(SF.col("cola")) + self.assertEqual("ARRAY_MAX(cola)", col.sql()) + + def test_sort_array(self): + col_str = SF.sort_array("cola", False) + self.assertEqual("SORT_ARRAY(cola, FALSE)", col_str.sql()) + col = SF.sort_array(SF.col("cola"), False) + self.assertEqual("SORT_ARRAY(cola, FALSE)", col.sql()) + col_no_sort = SF.sort_array("cola") + self.assertEqual("SORT_ARRAY(cola)", col_no_sort.sql()) + + def test_array_sort(self): + col_str = SF.array_sort("cola") + self.assertEqual("ARRAY_SORT(cola)", col_str.sql()) + col = SF.array_sort(SF.col("cola")) + self.assertEqual("ARRAY_SORT(cola)", col.sql()) + + def test_reverse(self): + col_str = SF.reverse("cola") + self.assertEqual("REVERSE(cola)", col_str.sql()) + col = SF.reverse(SF.col("cola")) + self.assertEqual("REVERSE(cola)", col.sql()) + + def test_flatten(self): + col_str = SF.flatten("cola") + self.assertEqual("FLATTEN(cola)", col_str.sql()) + col = SF.flatten(SF.col("cola")) + self.assertEqual("FLATTEN(cola)", col.sql()) + + def test_map_keys(self): + col_str = SF.map_keys("cola") + self.assertEqual("MAP_KEYS(cola)", col_str.sql()) + col = SF.map_keys(SF.col("cola")) + self.assertEqual("MAP_KEYS(cola)", col.sql()) + + def test_map_values(self): + col_str = SF.map_values("cola") + self.assertEqual("MAP_VALUES(cola)", col_str.sql()) + col = SF.map_values(SF.col("cola")) + self.assertEqual("MAP_VALUES(cola)", col.sql()) + + def test_map_entries(self): + col_str = SF.map_entries("cola") + self.assertEqual("MAP_ENTRIES(cola)", col_str.sql()) + col = SF.map_entries(SF.col("cola")) + self.assertEqual("MAP_ENTRIES(cola)", col.sql()) + + def test_map_from_entries(self): + col_str = SF.map_from_entries("cola") + self.assertEqual("MAP_FROM_ENTRIES(cola)", col_str.sql()) + col = SF.map_from_entries(SF.col("cola")) + self.assertEqual("MAP_FROM_ENTRIES(cola)", col.sql()) + + def test_array_repeat(self): + col_str = SF.array_repeat("cola", 2) + self.assertEqual("ARRAY_REPEAT(cola, 2)", col_str.sql()) + col = SF.array_repeat(SF.col("cola"), 2) + self.assertEqual("ARRAY_REPEAT(cola, 2)", col.sql()) + + def test_array_zip(self): + col_str = SF.array_zip("cola", "colb") + self.assertEqual("ARRAY_ZIP(cola, colb)", col_str.sql()) + col = SF.array_zip(SF.col("cola"), SF.col("colb")) + self.assertEqual("ARRAY_ZIP(cola, colb)", col.sql()) + col_single = SF.array_zip("cola") + self.assertEqual("ARRAY_ZIP(cola)", col_single.sql()) + + def test_map_concat(self): + col_str = SF.map_concat("cola", "colb") + self.assertEqual("MAP_CONCAT(cola, colb)", col_str.sql()) + col = SF.map_concat(SF.col("cola"), SF.col("colb")) + self.assertEqual("MAP_CONCAT(cola, colb)", col.sql()) + col_single = SF.map_concat("cola") + self.assertEqual("MAP_CONCAT(cola)", col_single.sql()) + + def test_sequence(self): + col_str = SF.sequence("cola", "colb", "colc") + self.assertEqual("SEQUENCE(cola, colb, colc)", col_str.sql()) + col = SF.sequence(SF.col("cola"), SF.col("colb"), SF.col("colc")) + self.assertEqual("SEQUENCE(cola, colb, colc)", col.sql()) + col_no_step = SF.sequence("cola", "colb") + self.assertEqual("SEQUENCE(cola, colb)", col_no_step.sql()) + + def test_from_csv(self): + col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy")) + self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + col_no_option = SF.from_csv("cola", "cola INT") + self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql()) + + def test_aggregate(self): + col_str = SF.aggregate("cola", SF.lit(0), lambda acc, x: acc + x, lambda acc: acc * 2) + self.assertEqual("AGGREGATE(cola, 0, (acc, x) -> acc + x, acc -> acc * 2)", col_str.sql()) + col = SF.aggregate(SF.col("cola"), SF.lit(0), lambda acc, x: acc + x, lambda acc: acc * 2) + self.assertEqual("AGGREGATE(cola, 0, (acc, x) -> acc + x, acc -> acc * 2)", col.sql()) + col_no_finish = SF.aggregate("cola", SF.lit(0), lambda acc, x: acc + x) + self.assertEqual("AGGREGATE(cola, 0, (acc, x) -> acc + x)", col_no_finish.sql()) + col_custom_names = SF.aggregate( + "cola", + SF.lit(0), + lambda accumulator, target: accumulator + target, + lambda accumulator: accumulator * 2, + "accumulator", + "target", + ) + self.assertEqual( + "AGGREGATE(cola, 0, (accumulator, target) -> accumulator + target, accumulator -> accumulator * 2)", + col_custom_names.sql(), + ) + + def test_transform(self): + col_str = SF.transform("cola", lambda x: x * 2) + self.assertEqual("TRANSFORM(cola, x -> x * 2)", col_str.sql()) + col = SF.transform(SF.col("cola"), lambda x, i: x * i) + self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql()) + col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count, "target", "row_count") + + self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()) + + def test_exists(self): + col_str = SF.exists("cola", lambda x: x % 2 == 0) + self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col_str.sql()) + col = SF.exists(SF.col("cola"), lambda x: x % 2 == 0) + self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col.sql()) + col_custom_name = SF.exists("cola", lambda target: target > 0, "target") + self.assertEqual("EXISTS(cola, target -> target > 0)", col_custom_name.sql()) + + def test_forall(self): + col_str = SF.forall("cola", lambda x: x.rlike("foo")) + self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col_str.sql()) + col = SF.forall(SF.col("cola"), lambda x: x.rlike("foo")) + self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col.sql()) + col_custom_name = SF.forall("cola", lambda target: target.rlike("foo"), "target") + self.assertEqual("FORALL(cola, target -> target RLIKE 'foo')", col_custom_name.sql()) + + def test_filter(self): + col_str = SF.filter("cola", lambda x: SF.month(SF.to_date(x)) > SF.lit(6)) + self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql()) + col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i)) + self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql()) + col_custom_names = SF.filter( + "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count), "target", "row_count" + ) + + self.assertEqual( + "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql() + ) + + def test_zip_with(self): + col_str = SF.zip_with("cola", "colb", lambda x, y: SF.concat_ws("_", x, y)) + self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col_str.sql()) + col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y)) + self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql()) + col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r), "l", "r") + self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()) + + def test_transform_keys(self): + col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k)) + self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col_str.sql()) + col = SF.transform_keys(SF.col("cola"), lambda k, v: SF.upper(k)) + self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col.sql()) + col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key), "key", "_") + self.assertEqual("TRANSFORM_KEYS(cola, (key, _) -> UPPER(key))", col_custom_names.sql()) + + def test_transform_values(self): + col_str = SF.transform_values("cola", lambda k, v: SF.upper(v)) + self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col_str.sql()) + col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v)) + self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql()) + col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value), "_", "value") + self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()) + + def test_map_filter(self): + col_str = SF.map_filter("cola", lambda k, v: k > v) + self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col_str.sql()) + col = SF.map_filter(SF.col("cola"), lambda k, v: k > v) + self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col.sql()) + col_custom_names = SF.map_filter("cola", lambda key, value: key > value, "key", "value") + self.assertEqual("MAP_FILTER(cola, (key, value) -> key > value)", col_custom_names.sql()) diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py new file mode 100644 index 0000000..158dcec --- /dev/null +++ b/tests/dataframe/unit/test_session.py @@ -0,0 +1,114 @@ +from unittest import mock + +import sqlglot +from sqlglot.dataframe.sql import functions as F +from sqlglot.dataframe.sql import types +from sqlglot.dataframe.sql.session import SparkSession +from sqlglot.schema import MappingSchema +from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator + + +class TestDataframeSession(DataFrameSQLValidator): + def test_cdf_one_row(self): + df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"]) + expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2)) AS `a2`(`cola`, `colb`)" + self.compare_sql(df, expected) + + def test_cdf_multiple_rows(self): + df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"]) + expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`cola`, `colb`)" + self.compare_sql(df, expected) + + def test_cdf_no_schema(self): + df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]]) + expected = ( + "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)" + ) + self.compare_sql(df, expected) + + def test_cdf_row_mixed_primitives(self): + df = self.spark.createDataFrame([[1, 10.1, "test", False, None]]) + expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM (VALUES (1, 10.1, 'test', FALSE, NULL)) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)" + self.compare_sql(df, expected) + + def test_cdf_dict_rows(self): + df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}]) + expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 'test'), (2, 'test2')) AS `a2`(`cola`, `colb`)" + self.compare_sql(df, expected) + + def test_cdf_str_schema(self): + df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING") + expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)" + self.compare_sql(df, expected) + + def test_typed_schema_basic(self): + schema = types.StructType( + [ + types.StructField("cola", types.IntegerType()), + types.StructField("colb", types.StringType()), + ] + ) + df = self.spark.createDataFrame([[1, "test"]], schema) + expected = "SELECT CAST(`a2`.`cola` AS int) AS `cola`, CAST(`a2`.`colb` AS string) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)" + self.compare_sql(df, expected) + + def test_typed_schema_nested(self): + schema = types.StructType( + [ + types.StructField( + "cola", + types.StructType( + [ + types.StructField("sub_cola", types.IntegerType()), + types.StructField("sub_colb", types.StringType()), + ] + ), + ) + ] + ) + df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema) + expected = "SELECT CAST(`a2`.`cola` AS struct) AS `cola` FROM (VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`))) AS `a2`(`cola`)" + self.compare_sql(df, expected) + + @mock.patch("sqlglot.schema", MappingSchema()) + def test_sql_select_only(self): + # TODO: Do exact matches once CTE names are deterministic + query = "SELECT cola, colb FROM table" + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + df = self.spark.sql(query) + self.assertIn( + "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False) + ) + + @mock.patch("sqlglot.schema", MappingSchema()) + def test_sql_with_aggs(self): + # TODO: Do exact matches once CTE names are deterministic + query = "SELECT cola, colb FROM table" + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb")) + result = df.sql(pretty=False, optimize=False)[0] + self.assertIn("SELECT cola, colb FROM table", result) + self.assertIn("SUM(colb)", result) + self.assertIn("GROUP BY cola", result) + + @mock.patch("sqlglot.schema", MappingSchema()) + def test_sql_create(self): + query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1" + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + df = self.spark.sql(query) + expected = "CREATE TABLE new_table AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" + self.compare_sql(df, expected) + + @mock.patch("sqlglot.schema", MappingSchema()) + def test_sql_insert(self): + query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1" + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + df = self.spark.sql(query) + expected = ( + "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" + ) + self.compare_sql(df, expected) + + def test_session_create_builder_patterns(self): + spark = SparkSession() + self.assertEqual(spark.builder.appName("abc").getOrCreate(), spark) diff --git a/tests/dataframe/unit/test_types.py b/tests/dataframe/unit/test_types.py new file mode 100644 index 0000000..1f6c5dc --- /dev/null +++ b/tests/dataframe/unit/test_types.py @@ -0,0 +1,70 @@ +import unittest + +from sqlglot.dataframe.sql import types + + +class TestDataframeTypes(unittest.TestCase): + def test_string(self): + self.assertEqual("string", types.StringType().simpleString()) + + def test_char(self): + self.assertEqual("char(100)", types.CharType(100).simpleString()) + + def test_varchar(self): + self.assertEqual("varchar(65)", types.VarcharType(65).simpleString()) + + def test_binary(self): + self.assertEqual("binary", types.BinaryType().simpleString()) + + def test_boolean(self): + self.assertEqual("boolean", types.BooleanType().simpleString()) + + def test_date(self): + self.assertEqual("date", types.DateType().simpleString()) + + def test_timestamp(self): + self.assertEqual("timestamp", types.TimestampType().simpleString()) + + def test_timestamp_ntz(self): + self.assertEqual("timestamp_ntz", types.TimestampNTZType().simpleString()) + + def test_decimal(self): + self.assertEqual("decimal(10, 3)", types.DecimalType(10, 3).simpleString()) + + def test_double(self): + self.assertEqual("double", types.DoubleType().simpleString()) + + def test_float(self): + self.assertEqual("float", types.FloatType().simpleString()) + + def test_byte(self): + self.assertEqual("tinyint", types.ByteType().simpleString()) + + def test_integer(self): + self.assertEqual("int", types.IntegerType().simpleString()) + + def test_long(self): + self.assertEqual("bigint", types.LongType().simpleString()) + + def test_short(self): + self.assertEqual("smallint", types.ShortType().simpleString()) + + def test_array(self): + self.assertEqual("array", types.ArrayType(types.IntegerType()).simpleString()) + + def test_map(self): + self.assertEqual("map", types.MapType(types.IntegerType(), types.StringType()).simpleString()) + + def test_struct_field(self): + self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString()) + + def test_struct_type(self): + self.assertEqual( + "struct", + types.StructType( + [ + types.StructField("cola", types.IntegerType()), + types.StructField("colb", types.StringType()), + ] + ).simpleString(), + ) diff --git a/tests/dataframe/unit/test_window.py b/tests/dataframe/unit/test_window.py new file mode 100644 index 0000000..eea4582 --- /dev/null +++ b/tests/dataframe/unit/test_window.py @@ -0,0 +1,60 @@ +import unittest + +from sqlglot.dataframe.sql import functions as F +from sqlglot.dataframe.sql.window import Window, WindowSpec + + +class TestDataframeWindow(unittest.TestCase): + def test_window_spec_partition_by(self): + partition_by = WindowSpec().partitionBy(F.col("cola"), F.col("colb")) + self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql()) + + def test_window_spec_order_by(self): + order_by = WindowSpec().orderBy("cola", "colb") + self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql()) + + def test_window_spec_rows_between(self): + rows_between = WindowSpec().rowsBetween(3, 5) + self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql()) + + def test_window_spec_range_between(self): + range_between = WindowSpec().rangeBetween(3, 5) + self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql()) + + def test_window_partition_by(self): + partition_by = Window.partitionBy(F.col("cola"), F.col("colb")) + self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql()) + + def test_window_order_by(self): + order_by = Window.orderBy("cola", "colb") + self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql()) + + def test_window_rows_between(self): + rows_between = Window.rowsBetween(3, 5) + self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql()) + + def test_window_range_between(self): + range_between = Window.rangeBetween(3, 5) + self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql()) + + def test_window_rows_unbounded(self): + rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2) + self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql()) + rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing) + self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql()) + rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + self.assertEqual( + "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql() + ) + + def test_window_range_unbounded(self): + range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2) + self.assertEqual( + "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql() + ) + range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing) + self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql()) + range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing) + self.assertEqual( + "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql() + ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index a1e1262..e1524e9 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -694,29 +694,6 @@ class TestDialect(Validator): }, ) - # https://dev.mysql.com/doc/refman/8.0/en/join.html - # https://www.postgresql.org/docs/current/queries-table-expressions.html - def test_joined_tables(self): - self.validate_identity("SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)") - self.validate_identity("SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)") - self.validate_identity("SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)") - self.validate_identity("SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)") - - self.validate_all( - "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)", - write={ - "postgres": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)", - "mysql": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)", - }, - ) - self.validate_all( - "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)", - write={ - "postgres": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)", - "mysql": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)", - }, - ) - def test_lateral_subquery(self): self.validate_identity( "SELECT art FROM tbl1 INNER JOIN LATERAL (SELECT art FROM tbl2) AS tbl2 ON tbl1.art = tbl2.art" @@ -856,7 +833,7 @@ class TestDialect(Validator): "postgres": "x ILIKE '%y'", "presto": "LOWER(x) LIKE '%y'", "snowflake": "x ILIKE '%y'", - "spark": "LOWER(x) LIKE '%y'", + "spark": "x ILIKE '%y'", "sqlite": "LOWER(x) LIKE '%y'", "starrocks": "LOWER(x) LIKE '%y'", "trino": "LOWER(x) LIKE '%y'", diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 298b3e9..625156b 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -48,7 +48,7 @@ class TestDuckDB(Validator): self.validate_all( "STRPTIME(x, '%y-%-m')", write={ - "bigquery": "STR_TO_TIME(x, '%y-%-m')", + "bigquery": "PARSE_TIMESTAMP('%y-%m', x)", "duckdb": "STRPTIME(x, '%y-%-m')", "presto": "DATE_PARSE(x, '%y-%c')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy-M')) AS TIMESTAMP)", @@ -63,6 +63,16 @@ class TestDuckDB(Validator): "hive": "CAST(x AS TIMESTAMP)", }, ) + self.validate_all( + "STRPTIME(x, '%-m/%-d/%y %-I:%M %p')", + write={ + "bigquery": "PARSE_TIMESTAMP('%m/%d/%y %I:%M %p', x)", + "duckdb": "STRPTIME(x, '%-m/%-d/%y %-I:%M %p')", + "presto": "DATE_PARSE(x, '%c/%e/%y %l:%i %p')", + "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'M/d/yy h:mm a')) AS TIMESTAMP)", + "spark": "TO_TIMESTAMP(x, 'M/d/yy h:mm a')", + }, + ) def test_duckdb(self): self.validate_all( @@ -268,6 +278,17 @@ class TestDuckDB(Validator): "spark": "MONTH('2021-03-01')", }, ) + self.validate_all( + "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))", + write={ + "duckdb": "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))", + "presto": "CONCAT(ARRAY[1, 2], ARRAY[3, 4])", + "hive": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))", + "spark": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))", + "snowflake": "ARRAY_CAT([1, 2], [3, 4])", + "bigquery": "ARRAY_CONCAT([1, 2], [3, 4])", + }, + ) with self.assertRaises(UnsupportedError): transpile( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 723e27c..a25871c 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -31,6 +31,24 @@ class TestMySQL(Validator): "mysql": "_utf8mb4 'hola'", }, ) + self.validate_all( + "N 'some text'", + read={ + "mysql": "N'some text'", + }, + write={ + "mysql": "N 'some text'", + }, + ) + self.validate_all( + "_latin1 x'4D7953514C'", + read={ + "mysql": "_latin1 X'4D7953514C'", + }, + write={ + "mysql": "_latin1 x'4D7953514C'", + }, + ) def test_hexadecimal_literal(self): self.validate_all( diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 4b8f3c3..35141e2 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -69,6 +69,8 @@ class TestPostgres(Validator): self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)") self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')") self.validate_identity("COMMENT ON TABLE mytable IS 'this'") + self.validate_identity("SELECT e'\\xDEADBEEF'") + self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") self.validate_all( "CREATE TABLE x (a UUID, b BYTEA)", @@ -204,3 +206,11 @@ class TestPostgres(Validator): """'{"a":[1,2,3],"b":[4,5,6]}'::json#>>'{a,2}'""", write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON)#>>'{a,2}'"""}, ) + self.validate_all( + "SELECT $$a$$", + write={"postgres": "SELECT 'a'"}, + ) + self.validate_all( + "SELECT $$Dianne's horse$$", + write={"postgres": "SELECT 'Dianne''s horse'"}, + ) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 10c9d35..098ad2b 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -321,7 +321,7 @@ class TestPresto(Validator): "duckdb": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", "presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo", "hive": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", - "spark": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + "spark": "SELECT APPROX_COUNT_DISTINCT(a, 0.1) FROM foo", }, ) self.validate_all( @@ -329,7 +329,7 @@ class TestPresto(Validator): write={ "presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo", "hive": UnsupportedError, - "spark": UnsupportedError, + "spark": "SELECT APPROX_COUNT_DISTINCT(a, 0.1) FROM foo", }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 8a33e2d..159b643 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -65,7 +65,7 @@ class TestSnowflake(Validator): self.validate_all( "SELECT TO_TIMESTAMP('2013-04-05 01:02:03')", write={ - "bigquery": "SELECT STR_TO_TIME('2013-04-05 01:02:03', '%Y-%m-%d %H:%M:%S')", + "bigquery": "SELECT PARSE_TIMESTAMP('%Y-%m-%d %H:%M:%S', '2013-04-05 01:02:03')", "snowflake": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-mm-dd hh24:mi:ss')", "spark": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-MM-dd HH:mm:ss')", }, @@ -73,16 +73,17 @@ class TestSnowflake(Validator): self.validate_all( "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", read={ - "bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')", + "bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %H:%M:%S', '04/05/2013 01:02:03')", "duckdb": "SELECT STRPTIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')", "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", }, write={ - "bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')", + "bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %H:%M:%S', '04/05/2013 01:02:03')", "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", "spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')", }, ) + self.validate_all( "SELECT IFF(TRUE, 'true', 'false')", write={ @@ -240,11 +241,25 @@ class TestSnowflake(Validator): }, ) self.validate_all( - "SELECT DATE_PART(month FROM a::DATETIME)", + "SELECT DATE_PART(month, a::DATETIME)", write={ "snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))", }, ) + self.validate_all( + "SELECT DATE_PART(epoch_second, foo) as ddate from table_name", + write={ + "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) AS ddate FROM table_name", + "presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) AS ddate FROM table_name", + }, + ) + self.validate_all( + "SELECT DATE_PART(epoch_milliseconds, foo) as ddate from table_name", + write={ + "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) * 1000 AS ddate FROM table_name", + "presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) * 1000 AS ddate FROM table_name", + }, + ) def test_semi_structured_types(self): self.validate_identity("SELECT CAST(a AS VARIANT)") diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index b061784..9a6bc36 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -45,3 +45,29 @@ class TestTSQL(Validator): "tsql": "CAST(x AS DATETIME2)", }, ) + + def test_charindex(self): + self.validate_all( + "CHARINDEX(x, y, 9)", + write={ + "spark": "LOCATE(x, y, 9)", + }, + ) + self.validate_all( + "CHARINDEX(x, y)", + write={ + "spark": "LOCATE(x, y)", + }, + ) + self.validate_all( + "CHARINDEX('sub', 'testsubstring', 3)", + write={ + "spark": "LOCATE('sub', 'testsubstring', 3)", + }, + ) + self.validate_all( + "CHARINDEX('sub', 'testsubstring')", + write={ + "spark": "LOCATE('sub', 'testsubstring')", + }, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 57e51e0..67e4cab 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -513,6 +513,8 @@ ALTER TYPE electronic_mail RENAME TO email ANALYZE a.y DELETE FROM x WHERE y > 1 DELETE FROM y +DELETE FROM event USING sales WHERE event.eventid = sales.eventid +DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid DROP TABLE a DROP TABLE a.b DROP TABLE IF EXISTS a @@ -563,3 +565,8 @@ WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z SELECT ((SELECT 1) + 1) SELECT * FROM project.dataset.INFORMATION_SCHEMA.TABLES +SELECT * FROM (table1 AS t1 LEFT JOIN table2 AS t2 ON 1 = 1) +SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1) +SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3) +SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo) +SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl) diff --git a/tests/fixtures/optimizer/merge_subqueries.sql b/tests/fixtures/optimizer/merge_subqueries.sql index a82e1ed..4a3ad4b 100644 --- a/tests/fixtures/optimizer/merge_subqueries.sql +++ b/tests/fixtures/optimizer/merge_subqueries.sql @@ -287,3 +287,27 @@ SELECT FROM t1; SELECT x.a AS a, x.b AS b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x AS x; + +# title: Values Test +# dialect: spark +WITH t1 AS ( + SELECT + a1.cola + FROM + VALUES (1) AS a1(cola) +), t2 AS ( + SELECT + a2.cola + FROM + VALUES (1) AS a2(cola) +) +SELECT /*+ BROADCAST(t2) */ + t1.cola, + t2.cola, +FROM + t1 + JOIN + t2 + ON + t1.cola = t2.cola; +SELECT /*+ BROADCAST(a2) */ a1.cola AS cola, a2.cola AS cola FROM VALUES (1) AS a1(cola) JOIN VALUES (1) AS a2(cola) ON a1.cola = a2.cola; diff --git a/tests/fixtures/optimizer/pushdown_predicates.sql b/tests/fixtures/optimizer/pushdown_predicates.sql index ef591ec..dd318a2 100644 --- a/tests/fixtures/optimizer/pushdown_predicates.sql +++ b/tests/fixtures/optimizer/pushdown_predicates.sql @@ -33,3 +33,6 @@ SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y. with t1 as (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) as row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1; WITH t1 AS (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1; + +WITH m AS (SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)), n AS (SELECT a, b FROM m WHERE m.a = 1), o AS (SELECT a, b FROM m WHERE m.a = 2) SELECT n.a, n.b, n.a, o.b FROM n FULL OUTER JOIN o ON n.a = o.a; +WITH m AS (SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)), n AS (SELECT a, b FROM m WHERE m.a = 1), o AS (SELECT a, b FROM m WHERE m.a = 2) SELECT n.a, n.b, n.a, o.b FROM n FULL OUTER JOIN o ON n.a = o.a; diff --git a/tests/fixtures/optimizer/pushdown_projections.sql b/tests/fixtures/optimizer/pushdown_projections.sql index b03ffab..ba4bf45 100644 --- a/tests/fixtures/optimizer/pushdown_projections.sql +++ b/tests/fixtures/optimizer/pushdown_projections.sql @@ -22,6 +22,9 @@ SELECT "_q_0".a AS a FROM (SELECT DISTINCT x.a AS a, x.b AS b FROM x AS x) AS "_ SELECT a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM x); SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION ALL SELECT x.a AS a FROM x AS x) AS "_q_0"; +WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x UNION ALL SELECT z.b AS b, z.c AS c FROM z) SELECT a, b FROM t1; +WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x AS x UNION ALL SELECT z.b AS b, z.c AS c FROM z AS z) SELECT t1.a AS a, t1.b AS b FROM t1; + SELECT a FROM (SELECT a, b FROM x UNION SELECT a, b FROM x); SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x UNION SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 83a3bf8..858f232 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -72,6 +72,9 @@ SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a; SELECT a FROM x ORDER BY b; SELECT x.a AS a FROM x AS x ORDER BY x.b; +SELECT SUM(a) AS a FROM x ORDER BY SUM(a); +SELECT SUM(x.a) AS a FROM x AS x ORDER BY SUM(x.a); + # dialect: bigquery SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS row_num FROM x QUALIFY row_num = 1; SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x QUALIFY row_num = 1; diff --git a/tests/helpers.py b/tests/helpers.py index 2d200f6..dabaf1c 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -53,6 +53,8 @@ def string_to_bool(string): return string and string.lower() in ("true", "1") +SKIP_INTEGRATION = string_to_bool(os.environ.get("SKIP_INTEGRATION", "0").lower()) + TPCH_SCHEMA = { "lineitem": { "l_orderkey": "uint64", diff --git a/tests/test_executor.py b/tests/test_executor.py index c5841d3..ef1a706 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -7,11 +7,17 @@ from pandas.testing import assert_frame_equal from sqlglot import exp, parse_one from sqlglot.executor import execute from sqlglot.executor.python import Python -from tests.helpers import FIXTURES_DIR, TPCH_SCHEMA, load_sql_fixture_pairs +from tests.helpers import ( + FIXTURES_DIR, + SKIP_INTEGRATION, + TPCH_SCHEMA, + load_sql_fixture_pairs, +) DIR = FIXTURES_DIR + "/optimizer/tpc-h/" +@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set") class TestExecutor(unittest.TestCase): @classmethod def setUpClass(cls): diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 9ad2bf5..79b4ee5 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -123,13 +123,16 @@ class TestExpressions(unittest.TestCase): self.assertEqual(exp.table_name(parse_one("a.b.c", into=exp.Table)), "a.b.c") self.assertEqual(exp.table_name("a.b.c"), "a.b.c") + def test_table(self): + self.assertEqual(exp.table_("a", alias="b"), parse_one("select * from a b").find(exp.Table)) + def test_replace_tables(self): self.assertEqual( exp.replace_tables( - parse_one("select * from a join b join c.a join d.a join e.a"), + parse_one("select * from a AS a join b join c.a join d.a join e.a"), {"a": "a1", "b": "b.a", "c.a": "c.a2", "d.a": "d2"}, ).sql(), - 'SELECT * FROM "a1" JOIN "b"."a" JOIN "c"."a2" JOIN "d2" JOIN e.a', + "SELECT * FROM a1 AS a JOIN b.a JOIN c.a2 JOIN d2 JOIN e.a", ) def test_named_selects(self): @@ -495,11 +498,15 @@ class TestExpressions(unittest.TestCase): self.assertEqual(exp.convert(value).sql(), expected) def test_annotation_alias(self): - expression = parse_one("SELECT a, b AS B, c #comment, d AS D #another_comment FROM foo") + sql = "SELECT a, b AS B, c # comment, d AS D # another_comment FROM foo" + expression = parse_one(sql) self.assertEqual( [e.alias_or_name for e in expression.expressions], ["a", "B", "c", "D"], ) + self.assertEqual(expression.sql(), sql) + self.assertEqual(expression.expressions[2].name, "comment") + self.assertEqual(expression.sql(annotations=False), "SELECT a, b AS B, c, d AS D") def test_to_table(self): table_only = exp.to_table("table_name") @@ -514,6 +521,18 @@ class TestExpressions(unittest.TestCase): self.assertEqual(catalog_db_and_table.name, "table_name") self.assertEqual(catalog_db_and_table.args.get("db"), exp.to_identifier("db")) self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog")) + with self.assertRaises(ValueError): + exp.to_table(1) + + def test_to_column(self): + column_only = exp.to_column("column_name") + self.assertEqual(column_only.name, "column_name") + self.assertIsNone(column_only.args.get("table")) + table_and_column = exp.to_column("table_name.column_name") + self.assertEqual(table_and_column.name, "column_name") + self.assertEqual(table_and_column.args.get("table"), exp.to_identifier("table_name")) + with self.assertRaises(ValueError): + exp.to_column(1) def test_union(self): expression = parse_one("SELECT cola, colb UNION SELECT colx, coly") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index a67e9db..3b5990f 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -5,11 +5,11 @@ import duckdb from pandas.testing import assert_frame_equal import sqlglot -from sqlglot import exp, optimizer, parse_one, table +from sqlglot import exp, optimizer, parse_one from sqlglot.errors import OptimizeError from sqlglot.optimizer.annotate_types import annotate_types -from sqlglot.optimizer.schema import MappingSchema, ensure_schema from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope +from sqlglot.schema import MappingSchema from tests.helpers import ( TPCH_SCHEMA, load_sql_fixture_pairs, @@ -29,19 +29,19 @@ class TestOptimizer(unittest.TestCase): CREATE TABLE x (a INT, b INT); CREATE TABLE y (b INT, c INT); CREATE TABLE z (b INT, c INT); - + INSERT INTO x VALUES (1, 1); INSERT INTO x VALUES (2, 2); INSERT INTO x VALUES (2, 2); INSERT INTO x VALUES (3, 3); INSERT INTO x VALUES (null, null); - + INSERT INTO y VALUES (2, 2); INSERT INTO y VALUES (2, 2); INSERT INTO y VALUES (3, 3); INSERT INTO y VALUES (4, 4); INSERT INTO y VALUES (null, null); - + INSERT INTO y VALUES (3, 3); INSERT INTO y VALUES (3, 3); INSERT INTO y VALUES (4, 4); @@ -80,8 +80,8 @@ class TestOptimizer(unittest.TestCase): with self.subTest(title): self.assertEqual( - optimized.sql(pretty=pretty, dialect=dialect), expected, + optimized.sql(pretty=pretty, dialect=dialect), ) should_execute = meta.get("execute") @@ -223,85 +223,6 @@ class TestOptimizer(unittest.TestCase): def test_tpch(self): self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True) - def test_schema(self): - schema = ensure_schema( - { - "x": { - "a": "uint64", - } - } - ) - self.assertEqual( - schema.column_names( - table( - "x", - ) - ), - ["a"], - ) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db")) - with self.assertRaises(ValueError): - schema.column_names(table("x2")) - - schema = ensure_schema( - { - "db": { - "x": { - "a": "uint64", - } - } - } - ) - self.assertEqual(schema.column_names(table("x", db="db")), ["a"]) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c")) - with self.assertRaises(ValueError): - schema.column_names(table("x")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db2")) - with self.assertRaises(ValueError): - schema.column_names(table("x2", db="db")) - - schema = ensure_schema( - { - "c": { - "db": { - "x": { - "a": "uint64", - } - } - } - } - ) - self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"]) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db")) - with self.assertRaises(ValueError): - schema.column_names(table("x")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db", catalog="c2")) - with self.assertRaises(ValueError): - schema.column_names(table("x", db="db2")) - with self.assertRaises(ValueError): - schema.column_names(table("x2", db="db")) - - schema = ensure_schema( - MappingSchema( - { - "x": { - "a": "uint64", - } - } - ) - ) - self.assertEqual(schema.column_names(table("x")), ["a"]) - - with self.assertRaises(OptimizeError): - ensure_schema({}) - def test_file_schema(self): expression = parse_one( """ @@ -327,6 +248,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') SELECT x.b FROM x ), r AS ( SELECT y.b FROM y + ), z as ( + SELECT cola, colb FROM (VALUES(1, 'test')) AS tab(cola, colb) ) SELECT r.b, @@ -340,19 +263,23 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') """ expression = parse_one(sql) for scopes in traverse_scope(expression), list(build_scope(expression).traverse()): - self.assertEqual(len(scopes), 5) + self.assertEqual(len(scopes), 7) self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") - self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y") - self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") - self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql()) - - self.assertEqual(set(scopes[4].sources), {"q", "r", "s"}) - self.assertEqual(len(scopes[4].columns), 6) - self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"}) - self.assertEqual(scopes[4].source_columns("q"), []) - self.assertEqual(len(scopes[4].source_columns("r")), 2) - self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"}) + self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)") + self.assertEqual( + scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)" + ) + self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y") + self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b") + self.assertEqual(scopes[6].expression.sql(), parse_one(sql).sql()) + + self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"}) + self.assertEqual(len(scopes[6].columns), 6) + self.assertEqual(set(c.table for c in scopes[6].columns), {"r", "s"}) + self.assertEqual(scopes[6].source_columns("q"), []) + self.assertEqual(len(scopes[6].source_columns("r")), 2) + self.assertEqual(set(c.table for c in scopes[6].source_columns("r")), {"r"}) self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"}) self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b") diff --git a/tests/test_parser.py b/tests/test_parser.py index 4e86516..9afeae6 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -81,7 +81,7 @@ class TestParser(unittest.TestCase): self.assertIsInstance(ignore.expression(exp.Hint, y=""), exp.Hint) self.assertIsInstance(ignore.expression(exp.Hint), exp.Hint) - default = Parser() + default = Parser(error_level=ErrorLevel.RAISE) self.assertIsInstance(default.expression(exp.Hint, expressions=[""]), exp.Hint) default.expression(exp.Hint, y="") default.expression(exp.Hint) @@ -139,12 +139,12 @@ class TestParser(unittest.TestCase): ) assert expression.expressions[0].name == "annotation1" - assert expression.expressions[1].name == "annotation2:testing " + assert expression.expressions[1].name == "annotation2:testing" assert expression.expressions[2].name == "test#annotation" assert expression.expressions[3].name == "annotation3" assert expression.expressions[4].name == "annotation4" assert expression.expressions[5].name == "" - assert expression.expressions[6].name == " space" + assert expression.expressions[6].name == "space" def test_pretty_config_override(self): self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x") diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000..bab97d8 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,290 @@ +import unittest + +from sqlglot import table +from sqlglot.dataframe.sql import types as df_types +from sqlglot.schema import MappingSchema, ensure_schema + + +class TestSchema(unittest.TestCase): + def test_schema(self): + schema = ensure_schema( + { + "x": { + "a": "uint64", + } + } + ) + self.assertEqual( + schema.column_names( + table( + "x", + ) + ), + ["a"], + ) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db")) + with self.assertRaises(ValueError): + schema.column_names(table("x2")) + + with self.assertRaises(ValueError): + schema.add_table(table("y", db="db"), {"b": "string"}) + with self.assertRaises(ValueError): + schema.add_table(table("y", db="db", catalog="c"), {"b": "string"}) + + schema.add_table(table("y"), {"b": "string"}) + schema_with_y = { + "x": { + "a": "uint64", + }, + "y": { + "b": "string", + }, + } + self.assertEqual(schema.schema, schema_with_y) + + new_schema = schema.copy() + new_schema.add_table(table("z"), {"c": "string"}) + self.assertEqual(schema.schema, schema_with_y) + self.assertEqual( + new_schema.schema, + { + "x": { + "a": "uint64", + }, + "y": { + "b": "string", + }, + "z": { + "c": "string", + }, + }, + ) + schema.add_table(table("m"), {"d": "string"}) + schema.add_table(table("n"), {"e": "string"}) + schema_with_m_n = { + "x": { + "a": "uint64", + }, + "y": { + "b": "string", + }, + "m": { + "d": "string", + }, + "n": { + "e": "string", + }, + } + self.assertEqual(schema.schema, schema_with_m_n) + new_schema = schema.copy() + new_schema.add_table(table("o"), {"f": "string"}) + new_schema.add_table(table("p"), {"g": "string"}) + self.assertEqual(schema.schema, schema_with_m_n) + self.assertEqual( + new_schema.schema, + { + "x": { + "a": "uint64", + }, + "y": { + "b": "string", + }, + "m": { + "d": "string", + }, + "n": { + "e": "string", + }, + "o": { + "f": "string", + }, + "p": { + "g": "string", + }, + }, + ) + + schema = ensure_schema( + { + "db": { + "x": { + "a": "uint64", + } + } + } + ) + self.assertEqual(schema.column_names(table("x", db="db")), ["a"]) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c")) + with self.assertRaises(ValueError): + schema.column_names(table("x")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db2")) + with self.assertRaises(ValueError): + schema.column_names(table("x2", db="db")) + + with self.assertRaises(ValueError): + schema.add_table(table("y"), {"b": "string"}) + with self.assertRaises(ValueError): + schema.add_table(table("y", db="db", catalog="c"), {"b": "string"}) + + schema.add_table(table("y", db="db"), {"b": "string"}) + self.assertEqual( + schema.schema, + { + "db": { + "x": { + "a": "uint64", + }, + "y": { + "b": "string", + }, + } + }, + ) + + schema = ensure_schema( + { + "c": { + "db": { + "x": { + "a": "uint64", + } + } + } + } + ) + self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"]) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db")) + with self.assertRaises(ValueError): + schema.column_names(table("x")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c2")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db2")) + with self.assertRaises(ValueError): + schema.column_names(table("x2", db="db")) + + with self.assertRaises(ValueError): + schema.add_table(table("x"), {"b": "string"}) + with self.assertRaises(ValueError): + schema.add_table(table("x", db="db"), {"b": "string"}) + + schema.add_table(table("y", db="db", catalog="c"), {"a": "string", "b": "int"}) + self.assertEqual( + schema.schema, + { + "c": { + "db": { + "x": { + "a": "uint64", + }, + "y": { + "a": "string", + "b": "int", + }, + } + } + }, + ) + schema.add_table(table("z", db="db2", catalog="c"), {"c": "string", "d": "int"}) + self.assertEqual( + schema.schema, + { + "c": { + "db": { + "x": { + "a": "uint64", + }, + "y": { + "a": "string", + "b": "int", + }, + }, + "db2": { + "z": { + "c": "string", + "d": "int", + } + }, + } + }, + ) + schema.add_table(table("m", db="db2", catalog="c2"), {"e": "string", "f": "int"}) + self.assertEqual( + schema.schema, + { + "c": { + "db": { + "x": { + "a": "uint64", + }, + "y": { + "a": "string", + "b": "int", + }, + }, + "db2": { + "z": { + "c": "string", + "d": "int", + } + }, + }, + "c2": { + "db2": { + "m": { + "e": "string", + "f": "int", + } + } + }, + }, + ) + + schema = ensure_schema( + { + "x": { + "a": "uint64", + } + } + ) + self.assertEqual(schema.column_names(table("x")), ["a"]) + + schema = MappingSchema() + schema.add_table(table("x"), {"a": "string"}) + self.assertEqual( + schema.schema, + { + "x": { + "a": "string", + } + }, + ) + schema.add_table(table("y"), df_types.StructType([df_types.StructField("b", df_types.StringType())])) + self.assertEqual( + schema.schema, + { + "x": { + "a": "string", + }, + "y": { + "b": "string", + }, + }, + ) + + def test_schema_add_table_with_and_without_mapping(self): + schema = MappingSchema() + schema.add_table("test") + self.assertEqual(schema.column_names("test"), []) + schema.add_table("test", {"x": "string"}) + self.assertEqual(schema.column_names("test"), ["x"]) + schema.add_table("test", {"x": "string", "y": "int"}) + self.assertEqual(schema.column_names("test"), ["x", "y"]) + schema.add_table("test") + self.assertEqual(schema.column_names("test"), ["x", "y"]) -- cgit v1.2.3