summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/dataframe/__init__.py0
-rw-r--r--tests/dataframe/integration/__init__.py0
-rw-r--r--tests/dataframe/integration/dataframe_validator.py174
-rw-r--r--tests/dataframe/integration/test_dataframe.py1281
-rw-r--r--tests/dataframe/integration/test_grouped_data.py71
-rw-r--r--tests/dataframe/integration/test_session.py43
-rw-r--r--tests/dataframe/unit/__init__.py0
-rw-r--r--tests/dataframe/unit/dataframe_sql_validator.py28
-rw-r--r--tests/dataframe/unit/dataframe_test_base.py23
-rw-r--r--tests/dataframe/unit/test_column.py174
-rw-r--r--tests/dataframe/unit/test_dataframe.py43
-rw-r--r--tests/dataframe/unit/test_dataframe_writer.py95
-rw-r--r--tests/dataframe/unit/test_functions.py1632
-rw-r--r--tests/dataframe/unit/test_session.py101
-rw-r--r--tests/dataframe/unit/test_session_case_sensitivity.py87
-rw-r--r--tests/dataframe/unit/test_types.py73
-rw-r--r--tests/dataframe/unit/test_window.py75
-rw-r--r--tests/dialects/test_databricks.py55
-rw-r--r--tests/dialects/test_dialect.py8
-rw-r--r--tests/dialects/test_doris.py14
-rw-r--r--tests/dialects/test_mysql.py8
-rw-r--r--tests/dialects/test_postgres.py20
-rw-r--r--tests/dialects/test_snowflake.py11
-rw-r--r--tests/dialects/test_teradata.py2
-rw-r--r--tests/fixtures/identity.sql3
25 files changed, 86 insertions, 3935 deletions
diff --git a/tests/dataframe/__init__.py b/tests/dataframe/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/tests/dataframe/__init__.py
+++ /dev/null
diff --git a/tests/dataframe/integration/__init__.py b/tests/dataframe/integration/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/tests/dataframe/integration/__init__.py
+++ /dev/null
diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py
deleted file mode 100644
index 22d4982..0000000
--- a/tests/dataframe/integration/dataframe_validator.py
+++ /dev/null
@@ -1,174 +0,0 @@
-import typing as t
-import unittest
-import warnings
-
-import sqlglot
-from tests.helpers import SKIP_INTEGRATION
-
-if t.TYPE_CHECKING:
- from pyspark.sql import DataFrame as SparkDataFrame
-
-
-@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
-class DataFrameValidator(unittest.TestCase):
- spark = None
- sqlglot = None
- df_employee = None
- df_store = None
- df_district = None
- spark_employee_schema = None
- sqlglot_employee_schema = None
- spark_store_schema = None
- sqlglot_store_schema = None
- spark_district_schema = None
- sqlglot_district_schema = None
-
- @classmethod
- def setUpClass(cls):
- from pyspark import SparkConf
- from pyspark.sql import SparkSession, types
-
- from sqlglot.dataframe.sql import types as sqlglotSparkTypes
- from sqlglot.dataframe.sql.session import SparkSession as SqlglotSparkSession
-
- # This is for test `test_branching_root_dataframes`
- config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
- cls.spark = (
- SparkSession.builder.master("local[*]")
- .appName("Unit-tests")
- .config(conf=config)
- .getOrCreate()
- )
- cls.spark.sparkContext.setLogLevel("ERROR")
- cls.sqlglot = SqlglotSparkSession()
- cls.spark_employee_schema = types.StructType(
- [
- types.StructField("employee_id", types.IntegerType(), False),
- types.StructField("fname", types.StringType(), False),
- types.StructField("lname", types.StringType(), False),
- types.StructField("age", types.IntegerType(), False),
- types.StructField("store_id", types.IntegerType(), False),
- ]
- )
- cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
- [
- sqlglotSparkTypes.StructField(
- "employee_id", sqlglotSparkTypes.IntegerType(), False
- ),
- sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
- sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
- sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
- sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
- ]
- )
- employee_data = [
- (1, "Jack", "Shephard", 37, 1),
- (2, "John", "Locke", 65, 1),
- (3, "Kate", "Austen", 37, 2),
- (4, "Claire", "Littleton", 27, 2),
- (5, "Hugo", "Reyes", 29, 100),
- ]
- cls.df_employee = cls.spark.createDataFrame(
- data=employee_data, schema=cls.spark_employee_schema
- )
- cls.dfs_employee = cls.sqlglot.createDataFrame(
- data=employee_data, schema=cls.sqlglot_employee_schema
- )
- cls.df_employee.createOrReplaceTempView("employee")
-
- cls.spark_store_schema = types.StructType(
- [
- types.StructField("store_id", types.IntegerType(), False),
- types.StructField("store_name", types.StringType(), False),
- types.StructField("district_id", types.IntegerType(), False),
- types.StructField("num_sales", types.IntegerType(), False),
- ]
- )
- cls.sqlglot_store_schema = sqlglotSparkTypes.StructType(
- [
- sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
- sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
- sqlglotSparkTypes.StructField(
- "district_id", sqlglotSparkTypes.IntegerType(), False
- ),
- sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
- ]
- )
- store_data = [
- (1, "Hydra", 1, 37),
- (2, "Arrow", 2, 2000),
- ]
- cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
- cls.dfs_store = cls.sqlglot.createDataFrame(
- data=store_data, schema=cls.sqlglot_store_schema
- )
- cls.df_store.createOrReplaceTempView("store")
-
- cls.spark_district_schema = types.StructType(
- [
- types.StructField("district_id", types.IntegerType(), False),
- types.StructField("district_name", types.StringType(), False),
- types.StructField("manager_name", types.StringType(), False),
- ]
- )
- cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
- [
- sqlglotSparkTypes.StructField(
- "district_id", sqlglotSparkTypes.IntegerType(), False
- ),
- sqlglotSparkTypes.StructField(
- "district_name", sqlglotSparkTypes.StringType(), False
- ),
- sqlglotSparkTypes.StructField(
- "manager_name", sqlglotSparkTypes.StringType(), False
- ),
- ]
- )
- district_data = [
- (1, "Temple", "Dogen"),
- (2, "Lighthouse", "Jacob"),
- ]
- cls.df_district = cls.spark.createDataFrame(
- data=district_data, schema=cls.spark_district_schema
- )
- cls.dfs_district = cls.sqlglot.createDataFrame(
- data=district_data, schema=cls.sqlglot_district_schema
- )
- cls.df_district.createOrReplaceTempView("district")
- sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema, dialect="spark")
- sqlglot.schema.add_table("store", cls.sqlglot_store_schema, dialect="spark")
- sqlglot.schema.add_table("district", cls.sqlglot_district_schema, dialect="spark")
-
- def setUp(self) -> None:
- warnings.filterwarnings("ignore", category=ResourceWarning)
- self.df_spark_store = self.df_store.alias("df_store") # type: ignore
- self.df_spark_employee = self.df_employee.alias("df_employee") # type: ignore
- self.df_spark_district = self.df_district.alias("df_district") # type: ignore
- self.df_sqlglot_store = self.dfs_store.alias("store") # type: ignore
- self.df_sqlglot_employee = self.dfs_employee.alias("employee") # type: ignore
- self.df_sqlglot_district = self.dfs_district.alias("district") # type: ignore
-
- def compare_spark_with_sqlglot(
- self, df_spark, df_sqlglot, no_empty=True, skip_schema_compare=False
- ) -> t.Tuple["SparkDataFrame", "SparkDataFrame"]:
- def compare_schemas(schema_1, schema_2):
- for schema in [schema_1, schema_2]:
- for struct_field in schema.fields:
- struct_field.metadata = {}
- self.assertEqual(schema_1, schema_2)
-
- for statement in df_sqlglot.sql():
- actual_df_sqlglot = self.spark.sql(statement) # type: ignore
- df_sqlglot_results = actual_df_sqlglot.collect()
- df_spark_results = df_spark.collect()
- if not skip_schema_compare:
- compare_schemas(df_spark.schema, actual_df_sqlglot.schema)
- self.assertEqual(df_spark_results, df_sqlglot_results)
- if no_empty:
- self.assertNotEqual(len(df_spark_results), 0)
- self.assertNotEqual(len(df_sqlglot_results), 0)
- return df_spark, actual_df_sqlglot
-
- @classmethod
- def get_explain_plan(cls, df: "SparkDataFrame", mode: str = "extended") -> str:
- return df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), mode) # type: ignore
diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py
deleted file mode 100644
index 702c6ee..0000000
--- a/tests/dataframe/integration/test_dataframe.py
+++ /dev/null
@@ -1,1281 +0,0 @@
-from pyspark.sql import functions as F
-
-from sqlglot.dataframe.sql import functions as SF
-from tests.dataframe.integration.dataframe_validator import DataFrameValidator
-
-
-class TestDataframeFunc(DataFrameValidator):
- def test_simple_select(self):
- df_employee = self.df_spark_employee.select(F.col("employee_id"))
- dfs_employee = self.df_sqlglot_employee.select(SF.col("employee_id"))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_simple_select_from_table(self):
- df = self.df_spark_employee
- dfs = self.sqlglot.read.table("employee")
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_simple_select_df_attribute(self):
- df_employee = self.df_spark_employee.select(self.df_spark_employee.employee_id)
- dfs_employee = self.df_sqlglot_employee.select(self.df_sqlglot_employee.employee_id)
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_simple_select_df_dict(self):
- df_employee = self.df_spark_employee.select(self.df_spark_employee["employee_id"])
- dfs_employee = self.df_sqlglot_employee.select(self.df_sqlglot_employee["employee_id"])
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_multiple_selects(self):
- df_employee = self.df_spark_employee.select(
- self.df_spark_employee["employee_id"], F.col("fname"), self.df_spark_employee.lname
- )
- dfs_employee = self.df_sqlglot_employee.select(
- self.df_sqlglot_employee["employee_id"], SF.col("fname"), self.df_sqlglot_employee.lname
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_alias_no_op(self):
- df_employee = self.df_spark_employee.alias("df_employee")
- dfs_employee = self.df_sqlglot_employee.alias("dfs_employee")
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_alias_with_select(self):
- df_employee = self.df_spark_employee.alias("df_employee").select(
- self.df_spark_employee["employee_id"],
- F.col("df_employee.fname"),
- self.df_spark_employee.lname,
- )
- dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select(
- self.df_sqlglot_employee["employee_id"],
- SF.col("dfs_employee.fname"),
- self.df_sqlglot_employee.lname,
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_case_when_otherwise(self):
- df = self.df_spark_employee.select(
- F.when(
- (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
- F.lit("between 40 and 60"),
- )
- .when(F.col("age") < F.lit(40), "less than 40")
- .otherwise("greater than 60")
- )
-
- dfs = self.df_sqlglot_employee.select(
- SF.when(
- (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
- SF.lit("between 40 and 60"),
- )
- .when(SF.col("age") < SF.lit(40), "less than 40")
- .otherwise("greater than 60")
- )
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_case_when_no_otherwise(self):
- df = self.df_spark_employee.select(
- F.when(
- (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
- F.lit("between 40 and 60"),
- ).when(F.col("age") < F.lit(40), "less than 40")
- )
-
- dfs = self.df_sqlglot_employee.select(
- SF.when(
- (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
- SF.lit("between 40 and 60"),
- ).when(SF.col("age") < SF.lit(40), "less than 40")
- )
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_where_clause_single(self):
- df_employee = self.df_spark_employee.where(F.col("age") == F.lit(37))
- dfs_employee = self.df_sqlglot_employee.where(SF.col("age") == SF.lit(37))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_where_clause_multiple_and(self):
- df_employee = self.df_spark_employee.where(
- (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))
- )
- dfs_employee = self.df_sqlglot_employee.where(
- (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack"))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_where_many_and(self):
- df_employee = self.df_spark_employee.where(
- (F.col("age") == F.lit(37))
- & (F.col("fname") == F.lit("Jack"))
- & (F.col("lname") == F.lit("Shephard"))
- & (F.col("employee_id") == F.lit(1))
- )
- dfs_employee = self.df_sqlglot_employee.where(
- (SF.col("age") == SF.lit(37))
- & (SF.col("fname") == SF.lit("Jack"))
- & (SF.col("lname") == SF.lit("Shephard"))
- & (SF.col("employee_id") == SF.lit(1))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_where_clause_multiple_or(self):
- df_employee = self.df_spark_employee.where(
- (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))
- )
- dfs_employee = self.df_sqlglot_employee.where(
- (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate"))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_where_many_or(self):
- df_employee = self.df_spark_employee.where(
- (F.col("age") == F.lit(37))
- | (F.col("fname") == F.lit("Kate"))
- | (F.col("lname") == F.lit("Littleton"))
- | (F.col("employee_id") == F.lit(2))
- )
- dfs_employee = self.df_sqlglot_employee.where(
- (SF.col("age") == SF.lit(37))
- | (SF.col("fname") == SF.lit("Kate"))
- | (SF.col("lname") == SF.lit("Littleton"))
- | (SF.col("employee_id") == SF.lit(2))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_where_mixed_and_or(self):
- df_employee = self.df_spark_employee.where(
- ((F.col("age") == F.lit(65)) & (F.col("fname") == F.lit("John")))
- | ((F.col("lname") == F.lit("Shephard")) & (F.col("age") == F.lit(37)))
- )
- dfs_employee = self.df_sqlglot_employee.where(
- ((SF.col("age") == SF.lit(65)) & (SF.col("fname") == SF.lit("John")))
- | ((SF.col("lname") == SF.lit("Shephard")) & (SF.col("age") == SF.lit(37)))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_where_multiple_chained(self):
- df_employee = self.df_spark_employee.where(F.col("age") == F.lit(37)).where(
- self.df_spark_employee.fname == F.lit("Jack")
- )
- dfs_employee = self.df_sqlglot_employee.where(SF.col("age") == SF.lit(37)).where(
- self.df_sqlglot_employee.fname == SF.lit("Jack")
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_operators(self):
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] < F.lit(50))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] < SF.lit(50))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] <= F.lit(37))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] <= SF.lit(37))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] > F.lit(50))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] > SF.lit(50))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] >= F.lit(37))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] >= SF.lit(37))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] != F.lit(50))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] != SF.lit(50))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] == F.lit(37))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(
- self.df_spark_employee["age"] % F.lit(5) == F.lit(0)
- )
- dfs_employee = self.df_sqlglot_employee.where(
- self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(
- self.df_spark_employee["age"] + F.lit(5) > F.lit(28)
- )
- dfs_employee = self.df_sqlglot_employee.where(
- self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(
- self.df_spark_employee["age"] - F.lit(5) > F.lit(28)
- )
- dfs_employee = self.df_sqlglot_employee.where(
- self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(
- self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2)
- )
- dfs_employee = self.df_sqlglot_employee.where(
- self.df_sqlglot_employee["age"] * SF.lit(0.5)
- == self.df_sqlglot_employee["age"] / SF.lit(2)
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_join_inner(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store, on=["store_id"], how="inner"
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- F.col("store_id"),
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store, on=["store_id"], how="inner"
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- SF.col("store_id"),
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_join_inner_no_select(self):
- df_joined = self.df_spark_employee.select(
- F.col("store_id"), F.col("fname"), F.col("lname")
- ).join(
- self.df_spark_store.select(F.col("store_id"), F.col("store_name")),
- on=["store_id"],
- how="inner",
- )
- dfs_joined = self.df_sqlglot_employee.select(
- SF.col("store_id"), SF.col("fname"), SF.col("lname")
- ).join(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")),
- on=["store_id"],
- how="inner",
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_join_inner_equality_single(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store,
- on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
- how="inner",
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- self.df_spark_employee.store_id,
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- F.lit("literal_value"),
- )
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store,
- on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
- how="inner",
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- self.df_sqlglot_employee.store_id,
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- SF.lit("literal_value"),
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_join_inner_equality_multiple(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store,
- on=[
- self.df_spark_employee.store_id == self.df_spark_store.store_id,
- self.df_spark_employee.age == self.df_spark_store.num_sales,
- ],
- how="inner",
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- self.df_spark_employee.store_id,
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store,
- on=[
- self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
- self.df_sqlglot_employee.age == self.df_sqlglot_store.num_sales,
- ],
- how="inner",
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- self.df_sqlglot_employee.store_id,
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_join_inner_equality_multiple_bitwise_and(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store,
- on=(self.df_spark_store.store_id == self.df_spark_employee.store_id)
- & (self.df_spark_store.num_sales == self.df_spark_employee.age),
- how="inner",
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- self.df_spark_employee.store_id,
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store,
- on=(self.df_sqlglot_store.store_id == self.df_sqlglot_employee.store_id)
- & (self.df_sqlglot_store.num_sales == self.df_sqlglot_employee.age),
- how="inner",
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- self.df_sqlglot_employee.store_id,
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_join_left_outer(self):
- df_joined = (
- self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="left_outer")
- .select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- F.col("store_id"),
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- .orderBy(F.col("employee_id"))
- )
- dfs_joined = (
- self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="left_outer")
- .select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- SF.col("store_id"),
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- .orderBy(SF.col("employee_id"))
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_join_full_outer(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store, on=["store_id"], how="full_outer"
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- F.col("store_id"),
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store, on=["store_id"], how="full_outer"
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- SF.col("store_id"),
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_triple_join(self):
- df = (
- self.df_employee.join(
- self.df_store, on=self.df_employee.employee_id == self.df_store.store_id
- )
- .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id)
- .select(
- self.df_employee.employee_id,
- self.df_store.store_id,
- self.df_district.district_id,
- self.df_employee.fname,
- self.df_store.store_name,
- self.df_district.district_name,
- )
- )
- dfs = (
- self.dfs_employee.join(
- self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id
- )
- .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id)
- .select(
- self.dfs_employee.employee_id,
- self.dfs_store.store_id,
- self.dfs_district.district_id,
- self.dfs_employee.fname,
- self.dfs_store.store_name,
- self.dfs_district.district_name,
- )
- )
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_triple_join_no_select(self):
- df = (
- self.df_employee.join(
- self.df_store,
- on=self.df_employee["employee_id"] == self.df_store["store_id"],
- how="left",
- )
- .join(
- self.df_district,
- on=self.df_store["store_id"] == self.df_district["district_id"],
- how="left",
- )
- .orderBy(F.col("employee_id"))
- )
- dfs = (
- self.dfs_employee.join(
- self.dfs_store,
- on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"],
- how="left",
- )
- .join(
- self.dfs_district,
- on=self.dfs_store["store_id"] == self.dfs_district["district_id"],
- how="left",
- )
- .orderBy(SF.col("employee_id"))
- )
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_triple_joins_filter(self):
- df = (
- self.df_employee.join(
- self.df_store,
- on=self.df_employee["employee_id"] == self.df_store["store_id"],
- how="left",
- ).join(
- self.df_district,
- on=self.df_store["store_id"] == self.df_district["district_id"],
- how="left",
- )
- ).filter(F.coalesce(self.df_store["num_sales"], F.lit(0)) > 100)
- dfs = (
- self.dfs_employee.join(
- self.dfs_store,
- on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"],
- how="left",
- ).join(
- self.dfs_district,
- on=self.dfs_store["store_id"] == self.dfs_district["district_id"],
- how="left",
- )
- ).filter(SF.coalesce(self.dfs_store["num_sales"], SF.lit(0)) > 100)
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_triple_join_column_name_only(self):
- df = (
- self.df_employee.join(
- self.df_store,
- on=self.df_employee["employee_id"] == self.df_store["store_id"],
- how="left",
- )
- .join(self.df_district, on="district_id", how="left")
- .orderBy(F.col("employee_id"))
- )
- dfs = (
- self.dfs_employee.join(
- self.dfs_store,
- on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"],
- how="left",
- )
- .join(self.dfs_district, on="district_id", how="left")
- .orderBy(SF.col("employee_id"))
- )
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_join_select_and_select_start(self):
- df = self.df_spark_employee.select(
- F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
- ).join(self.df_spark_store, "store_id", "inner")
-
- dfs = self.df_sqlglot_employee.select(
- SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
- ).join(self.df_sqlglot_store, "store_id", "inner")
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_branching_root_dataframes(self):
- """
- Test a pattern that has non-intuitive behavior in spark
-
- Scenario: You do a self-join in a dataframe using an original dataframe and then a modified version
- of it. You then reference the columns by the dataframe name instead of the column function.
- Spark will use the root dataframe's column in the result.
- """
- df_hydra_employees_only = self.df_spark_employee.where(F.col("store_id") == F.lit(1))
- df_joined = (
- self.df_spark_employee.where(F.col("store_id") == F.lit(2))
- .alias("df_arrow_employees_only")
- .join(
- df_hydra_employees_only.alias("df_hydra_employees_only"),
- on=["store_id"],
- how="full_outer",
- )
- .select(
- self.df_spark_employee.fname,
- F.col("df_arrow_employees_only.fname"),
- df_hydra_employees_only.fname,
- F.col("df_hydra_employees_only.fname"),
- )
- )
-
- dfs_hydra_employees_only = self.df_sqlglot_employee.where(SF.col("store_id") == SF.lit(1))
- dfs_joined = (
- self.df_sqlglot_employee.where(SF.col("store_id") == SF.lit(2))
- .alias("dfs_arrow_employees_only")
- .join(
- dfs_hydra_employees_only.alias("dfs_hydra_employees_only"),
- on=["store_id"],
- how="full_outer",
- )
- .select(
- self.df_sqlglot_employee.fname,
- SF.col("dfs_arrow_employees_only.fname"),
- dfs_hydra_employees_only.fname,
- SF.col("dfs_hydra_employees_only.fname"),
- )
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_basic_union(self):
- df_unioned = self.df_spark_employee.select(F.col("employee_id"), F.col("age")).union(
- self.df_spark_store.select(F.col("store_id"), F.col("num_sales"))
- )
-
- dfs_unioned = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age")).union(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("num_sales"))
- )
- self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
-
- def test_union_with_join(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store,
- on="store_id",
- how="inner",
- )
- df_unioned = df_joined.select(F.col("store_id"), F.col("store_name")).union(
- self.df_spark_district.select(F.col("district_id"), F.col("district_name"))
- )
-
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store,
- on="store_id",
- how="inner",
- )
- dfs_unioned = dfs_joined.select(SF.col("store_id"), SF.col("store_name")).union(
- self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
- )
-
- self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
-
- def test_double_union_all(self):
- df_unioned = (
- self.df_spark_employee.select(F.col("employee_id"), F.col("fname"))
- .unionAll(self.df_spark_store.select(F.col("store_id"), F.col("store_name")))
- .unionAll(self.df_spark_district.select(F.col("district_id"), F.col("district_name")))
- )
-
- dfs_unioned = (
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"))
- .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")))
- .unionAll(
- self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
- )
- )
-
- self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
-
- def test_union_by_name(self):
- df = self.df_spark_employee.select(
- F.col("employee_id"), F.col("fname"), F.col("lname")
- ).unionByName(
- self.df_spark_store.select(
- F.col("store_name").alias("lname"),
- F.col("store_id").alias("employee_id"),
- F.col("store_name").alias("fname"),
- )
- )
-
- dfs = self.df_sqlglot_employee.select(
- SF.col("employee_id"), SF.col("fname"), SF.col("lname")
- ).unionByName(
- self.df_sqlglot_store.select(
- SF.col("store_name").alias("lname"),
- SF.col("store_id").alias("employee_id"),
- SF.col("store_name").alias("fname"),
- )
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_union_by_name_allow_missing(self):
- df = self.df_spark_employee.select(
- F.col("age"), F.col("employee_id"), F.col("fname"), F.col("lname")
- ).unionByName(
- self.df_spark_store.select(
- F.col("store_name").alias("lname"),
- F.col("store_id").alias("employee_id"),
- F.col("store_name").alias("fname"),
- F.col("num_sales"),
- ),
- allowMissingColumns=True,
- )
-
- dfs = self.df_sqlglot_employee.select(
- SF.col("age"), SF.col("employee_id"), SF.col("fname"), SF.col("lname")
- ).unionByName(
- self.df_sqlglot_store.select(
- SF.col("store_name").alias("lname"),
- SF.col("store_id").alias("employee_id"),
- SF.col("store_name").alias("fname"),
- SF.col("num_sales"),
- ),
- allowMissingColumns=True,
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_order_by_default(self):
- df = (
- self.df_spark_store.groupBy(F.col("district_id"))
- .agg(F.min("num_sales"))
- .orderBy(F.col("district_id"))
- )
-
- dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id"))
- .agg(SF.min("num_sales"))
- .orderBy(SF.col("district_id"))
- )
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_order_by_array_bool(self):
- df = (
- self.df_spark_store.groupBy(F.col("district_id"))
- .agg(F.min("num_sales").alias("total_sales"))
- .orderBy(F.col("total_sales"), F.col("district_id"), ascending=[1, 0])
- )
-
- dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id"))
- .agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(SF.col("total_sales"), SF.col("district_id"), ascending=[1, 0])
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_order_by_single_bool(self):
- df = (
- self.df_spark_store.groupBy(F.col("district_id"))
- .agg(F.min("num_sales").alias("total_sales"))
- .orderBy(F.col("total_sales"), F.col("district_id"), ascending=False)
- )
-
- dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id"))
- .agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(SF.col("total_sales"), SF.col("district_id"), ascending=False)
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_order_by_column_sort_method(self):
- df = (
- self.df_spark_store.groupBy(F.col("district_id"))
- .agg(F.min("num_sales").alias("total_sales"))
- .orderBy(F.col("total_sales").asc(), F.col("district_id").desc())
- )
-
- dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id"))
- .agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(SF.col("total_sales").asc(), SF.col("district_id").desc())
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_order_by_column_sort_method_nulls_last(self):
- df = (
- self.df_spark_store.groupBy(F.col("district_id"))
- .agg(F.min("num_sales").alias("total_sales"))
- .orderBy(
- F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()
- )
- )
-
- dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id"))
- .agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(
- SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()
- )
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_order_by_column_sort_method_nulls_first(self):
- df = (
- self.df_spark_store.groupBy(F.col("district_id"))
- .agg(F.min("num_sales").alias("total_sales"))
- .orderBy(
- F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()
- )
- )
-
- dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id"))
- .agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(
- SF.when(
- SF.col("district_id") == SF.lit(1), SF.col("district_id")
- ).desc_nulls_first()
- )
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_intersect(self):
- df_employee_duplicate = self.df_spark_employee.select(
- F.col("employee_id"), F.col("store_id")
- ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
-
- df_store_duplicate = self.df_spark_store.select(
- F.col("store_id"), F.col("district_id")
- ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
-
- df = df_employee_duplicate.intersect(df_store_duplicate)
-
- dfs_employee_duplicate = self.df_sqlglot_employee.select(
- SF.col("employee_id"), SF.col("store_id")
- ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
-
- dfs_store_duplicate = self.df_sqlglot_store.select(
- SF.col("store_id"), SF.col("district_id")
- ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
-
- dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate)
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_intersect_all(self):
- df_employee_duplicate = self.df_spark_employee.select(
- F.col("employee_id"), F.col("store_id")
- ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
-
- df_store_duplicate = self.df_spark_store.select(
- F.col("store_id"), F.col("district_id")
- ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
-
- df = df_employee_duplicate.intersectAll(df_store_duplicate)
-
- dfs_employee_duplicate = self.df_sqlglot_employee.select(
- SF.col("employee_id"), SF.col("store_id")
- ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
-
- dfs_store_duplicate = self.df_sqlglot_store.select(
- SF.col("store_id"), SF.col("district_id")
- ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
-
- dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate)
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_except_all(self):
- df_employee_duplicate = self.df_spark_employee.select(
- F.col("employee_id"), F.col("store_id")
- ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
-
- df_store_duplicate = self.df_spark_store.select(
- F.col("store_id"), F.col("district_id")
- ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
-
- df = df_employee_duplicate.exceptAll(df_store_duplicate)
-
- dfs_employee_duplicate = self.df_sqlglot_employee.select(
- SF.col("employee_id"), SF.col("store_id")
- ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
-
- dfs_store_duplicate = self.df_sqlglot_store.select(
- SF.col("store_id"), SF.col("district_id")
- ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
-
- dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate)
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_distinct(self):
- df = self.df_spark_employee.select(F.col("age")).distinct()
-
- dfs = self.df_sqlglot_employee.select(SF.col("age")).distinct()
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_union_distinct(self):
- df_unioned = (
- self.df_spark_employee.select(F.col("employee_id"), F.col("age"))
- .union(self.df_spark_employee.select(F.col("employee_id"), F.col("age")))
- .distinct()
- )
-
- dfs_unioned = (
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age"))
- .union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age")))
- .distinct()
- )
- self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
-
- def test_drop_duplicates_no_subset(self):
- df = self.df_spark_employee.select("age").dropDuplicates()
- dfs = self.df_sqlglot_employee.select("age").dropDuplicates()
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_drop_duplicates_subset(self):
- df = self.df_spark_employee.dropDuplicates(["age"])
- dfs = self.df_sqlglot_employee.dropDuplicates(["age"])
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_drop_na_default(self):
- df = self.df_spark_employee.select(
- F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
- ).dropna()
-
- dfs = self.df_sqlglot_employee.select(
- SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
- ).dropna()
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_dropna_how(self):
- df = self.df_spark_employee.select(
- F.lit(None), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
- ).dropna(how="all")
-
- dfs = self.df_sqlglot_employee.select(
- SF.lit(None), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
- ).dropna(how="all")
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_dropna_thresh(self):
- df = self.df_spark_employee.select(
- F.lit(None), F.lit(1), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
- ).dropna(how="any", thresh=2)
-
- dfs = self.df_sqlglot_employee.select(
- SF.lit(None),
- SF.lit(1),
- SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
- ).dropna(how="any", thresh=2)
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_dropna_subset(self):
- df = self.df_spark_employee.select(
- F.lit(None), F.lit(1), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
- ).dropna(thresh=1, subset="the_age")
-
- dfs = self.df_sqlglot_employee.select(
- SF.lit(None),
- SF.lit(1),
- SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
- ).dropna(thresh=1, subset="the_age")
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_dropna_na_function(self):
- df = self.df_spark_employee.select(
- F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
- ).na.drop()
-
- dfs = self.df_sqlglot_employee.select(
- SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
- ).na.drop()
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_fillna_default(self):
- df = self.df_spark_employee.select(
- F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
- ).fillna(100)
-
- dfs = self.df_sqlglot_employee.select(
- SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
- ).fillna(100)
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_fillna_dict_replacement(self):
- df = self.df_spark_employee.select(
- F.col("fname"),
- F.when(F.col("lname").startswith("L"), F.col("lname")).alias("l_lname"),
- F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age"),
- ).fillna({"fname": "Jacob", "l_lname": "NOT_LNAME"})
-
- dfs = self.df_sqlglot_employee.select(
- SF.col("fname"),
- SF.when(SF.col("lname").startswith("L"), SF.col("lname")).alias("l_lname"),
- SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
- ).fillna({"fname": "Jacob", "l_lname": "NOT_LNAME"})
-
- # For some reason the sqlglot results sets a column as nullable when it doesn't need to
- # This seems to be a nuance in how spark dataframe from sql works so we can ignore
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_fillna_na_func(self):
- df = self.df_spark_employee.select(
- F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
- ).na.fill(100)
-
- dfs = self.df_sqlglot_employee.select(
- SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
- ).na.fill(100)
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_replace_basic(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
- to_replace=37, value=100
- )
-
- dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
- to_replace=37, value=100
- )
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_replace_basic_subset(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
- to_replace=37, value=100, subset="age"
- )
-
- dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
- to_replace=37, value=100, subset="age"
- )
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_replace_mapping(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
- {37: 100}
- )
-
- dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
- {37: 100}
- )
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_replace_mapping_subset(self):
- df = self.df_spark_employee.select(
- F.col("age"), F.lit(37).alias("test_col"), F.lit(50).alias("test_col_2")
- ).replace({37: 100, 50: 1}, subset=["age", "test_col_2"])
-
- dfs = self.df_sqlglot_employee.select(
- SF.col("age"), SF.lit(37).alias("test_col"), SF.lit(50).alias("test_col_2")
- ).replace({37: 100, 50: 1}, subset=["age", "test_col_2"])
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_replace_na_func_basic(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).na.replace(
- to_replace=37, value=100
- )
-
- dfs = self.df_sqlglot_employee.select(
- SF.col("age"), SF.lit(37).alias("test_col")
- ).na.replace(to_replace=37, value=100)
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_with_column(self):
- df = self.df_spark_employee.withColumn("test", F.col("age"))
-
- dfs = self.df_sqlglot_employee.withColumn("test", SF.col("age"))
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_with_column_existing_name(self):
- df = self.df_spark_employee.withColumn("fname", F.lit("blah"))
-
- dfs = self.df_sqlglot_employee.withColumn("fname", SF.lit("blah"))
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_with_column_renamed(self):
- df = self.df_spark_employee.withColumnRenamed("fname", "first_name")
-
- dfs = self.df_sqlglot_employee.withColumnRenamed("fname", "first_name")
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_with_column_renamed_double(self):
- df = self.df_spark_employee.select(F.col("fname").alias("first_name")).withColumnRenamed(
- "first_name", "first_name_again"
- )
-
- dfs = self.df_sqlglot_employee.select(
- SF.col("fname").alias("first_name")
- ).withColumnRenamed("first_name", "first_name_again")
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_drop_column_single(self):
- df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age")
-
- dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop(
- "age"
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_drop_column_reference_join(self):
- df_spark_employee_cols = self.df_spark_employee.select(
- F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
- )
- df_spark_store_cols = self.df_spark_store.select(F.col("store_id"), F.col("store_name"))
- df = df_spark_employee_cols.join(df_spark_store_cols, on="store_id", how="inner").drop(
- df_spark_employee_cols.age,
- )
-
- df_sqlglot_employee_cols = self.df_sqlglot_employee.select(
- SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
- )
- df_sqlglot_store_cols = self.df_sqlglot_store.select(
- SF.col("store_id"), SF.col("store_name")
- )
- dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop(
- df_sqlglot_employee_cols.age,
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_limit(self):
- df = self.df_spark_employee.limit(1)
-
- dfs = self.df_sqlglot_employee.limit(1)
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_hint_broadcast_alias(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store.alias("store").hint("broadcast", "store"),
- on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
- how="inner",
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- self.df_spark_employee.store_id,
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store.alias("store").hint("broadcast", "store"),
- on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
- how="inner",
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- self.df_sqlglot_employee.store_id,
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined)
- self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df))
- self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs))
-
- def test_hint_broadcast_no_alias(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store.hint("broadcast"),
- on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
- how="inner",
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- self.df_spark_employee.store_id,
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store.hint("broadcast"),
- on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
- how="inner",
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- self.df_sqlglot_employee.store_id,
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined)
- self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df))
- self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs))
- self.assertEqual(
- "'UnresolvedHint BROADCAST, ['a2]", self.get_explain_plan(dfs).split("\n")[1]
- )
-
- def test_broadcast_func(self):
- df_joined = self.df_spark_employee.join(
- F.broadcast(self.df_spark_store),
- on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
- how="inner",
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- self.df_spark_employee.store_id,
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- dfs_joined = self.df_sqlglot_employee.join(
- SF.broadcast(self.df_sqlglot_store),
- on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
- how="inner",
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- self.df_sqlglot_employee.store_id,
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined)
- self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df))
- self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs))
- self.assertEqual(
- "'UnresolvedHint BROADCAST, ['a2]", self.get_explain_plan(dfs).split("\n")[1]
- )
-
- def test_repartition_by_num(self):
- """
- The results are different when doing the repartition on a table created using VALUES in SQL.
- So I just use the views instead for these tests
- """
- df = self.df_spark_employee.repartition(63)
-
- dfs = self.sqlglot.read.table("employee").repartition(63)
- df, dfs = self.compare_spark_with_sqlglot(df, dfs)
- spark_num_partitions = df.rdd.getNumPartitions()
- sqlglot_num_partitions = dfs.rdd.getNumPartitions()
- self.assertEqual(spark_num_partitions, 63)
- self.assertEqual(spark_num_partitions, sqlglot_num_partitions)
-
- def test_repartition_name_only(self):
- """
- We use the view here to help ensure the explain plans are similar enough to compare
- """
- df = self.df_spark_employee.repartition("age")
-
- dfs = self.sqlglot.read.table("employee").repartition("age")
- df, dfs = self.compare_spark_with_sqlglot(df, dfs)
- self.assertIn("RepartitionByExpression [age", self.get_explain_plan(df))
- self.assertIn("RepartitionByExpression [age", self.get_explain_plan(dfs))
-
- def test_repartition_num_and_multiple_names(self):
- """
- We use the view here to help ensure the explain plans are similar enough to compare
- """
- df = self.df_spark_employee.repartition(53, "age", "fname")
-
- dfs = self.sqlglot.read.table("employee").repartition(53, "age", "fname")
- df, dfs = self.compare_spark_with_sqlglot(df, dfs)
- spark_num_partitions = df.rdd.getNumPartitions()
- sqlglot_num_partitions = dfs.rdd.getNumPartitions()
- self.assertEqual(spark_num_partitions, 53)
- self.assertEqual(spark_num_partitions, sqlglot_num_partitions)
- self.assertIn("RepartitionByExpression [age#3, fname#1], 53", self.get_explain_plan(df))
- self.assertIn("RepartitionByExpression [age#3, fname#1], 53", self.get_explain_plan(dfs))
-
- def test_coalesce(self):
- df = self.df_spark_employee.coalesce(1)
- dfs = self.df_sqlglot_employee.coalesce(1)
- df, dfs = self.compare_spark_with_sqlglot(df, dfs)
- spark_num_partitions = df.rdd.getNumPartitions()
- sqlglot_num_partitions = dfs.rdd.getNumPartitions()
- self.assertEqual(spark_num_partitions, 1)
- self.assertEqual(spark_num_partitions, sqlglot_num_partitions)
-
- def test_cache_select(self):
- df_employee = (
- self.df_spark_employee.groupBy("store_id")
- .agg(F.countDistinct("employee_id").alias("num_employees"))
- .cache()
- )
- df_joined = df_employee.join(self.df_spark_store, on="store_id").select(
- self.df_spark_store.store_id, df_employee.num_employees
- )
- dfs_employee = (
- self.df_sqlglot_employee.groupBy("store_id")
- .agg(SF.countDistinct("employee_id").alias("num_employees"))
- .cache()
- )
- dfs_joined = dfs_employee.join(self.df_sqlglot_store, on="store_id").select(
- self.df_sqlglot_store.store_id, dfs_employee.num_employees
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_persist_select(self):
- df_employee = (
- self.df_spark_employee.groupBy("store_id")
- .agg(F.countDistinct("employee_id").alias("num_employees"))
- .persist()
- )
- df_joined = df_employee.join(self.df_spark_store, on="store_id").select(
- self.df_spark_store.store_id, df_employee.num_employees
- )
- dfs_employee = (
- self.df_sqlglot_employee.groupBy("store_id")
- .agg(SF.countDistinct("employee_id").alias("num_employees"))
- .persist()
- )
- dfs_joined = dfs_employee.join(self.df_sqlglot_store, on="store_id").select(
- self.df_sqlglot_store.store_id, dfs_employee.num_employees
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
diff --git a/tests/dataframe/integration/test_grouped_data.py b/tests/dataframe/integration/test_grouped_data.py
deleted file mode 100644
index 2768dda..0000000
--- a/tests/dataframe/integration/test_grouped_data.py
+++ /dev/null
@@ -1,71 +0,0 @@
-from pyspark.sql import functions as F
-
-from sqlglot.dataframe.sql import functions as SF
-from tests.dataframe.integration.dataframe_validator import DataFrameValidator
-
-
-class TestDataframeFunc(DataFrameValidator):
- def test_group_by(self):
- df_employee = self.df_spark_employee.groupBy(self.df_spark_employee.age).agg(
- F.min(self.df_spark_employee.employee_id)
- )
- dfs_employee = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).agg(
- SF.min(self.df_sqlglot_employee.employee_id)
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee, skip_schema_compare=True)
-
- def test_group_by_where_non_aggregate(self):
- df_employee = (
- self.df_spark_employee.groupBy(self.df_spark_employee.age)
- .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
- .where(F.col("age") > F.lit(50))
- )
- dfs_employee = (
- self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
- .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
- .where(SF.col("age") > SF.lit(50))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_group_by_where_aggregate_like_having(self):
- df_employee = (
- self.df_spark_employee.groupBy(self.df_spark_employee.age)
- .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
- .where(F.col("min_employee_id") > F.lit(1))
- )
- dfs_employee = (
- self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
- .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
- .where(SF.col("min_employee_id") > SF.lit(1))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_count(self):
- df = self.df_spark_employee.groupBy(self.df_spark_employee.age).count()
- dfs = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).count()
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_mean(self):
- df = self.df_spark_employee.groupBy().mean("age", "store_id")
- dfs = self.df_sqlglot_employee.groupBy().mean("age", "store_id")
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_avg(self):
- df = self.df_spark_employee.groupBy("age").avg("store_id")
- dfs = self.df_sqlglot_employee.groupBy("age").avg("store_id")
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_max(self):
- df = self.df_spark_employee.groupBy("age").max("store_id")
- dfs = self.df_sqlglot_employee.groupBy("age").max("store_id")
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_min(self):
- df = self.df_spark_employee.groupBy("age").min("store_id")
- dfs = self.df_sqlglot_employee.groupBy("age").min("store_id")
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_sum(self):
- df = self.df_spark_employee.groupBy("age").sum("store_id")
- dfs = self.df_sqlglot_employee.groupBy("age").sum("store_id")
- self.compare_spark_with_sqlglot(df, dfs)
diff --git a/tests/dataframe/integration/test_session.py b/tests/dataframe/integration/test_session.py
deleted file mode 100644
index 3bb3e20..0000000
--- a/tests/dataframe/integration/test_session.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from pyspark.sql import functions as F
-
-from sqlglot.dataframe.sql import functions as SF
-from tests.dataframe.integration.dataframe_validator import DataFrameValidator
-
-
-class TestSessionFunc(DataFrameValidator):
- def test_sql_simple_select(self):
- query = "SELECT fname, lname FROM employee"
- df = self.spark.sql(query)
- dfs = self.sqlglot.sql(query)
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_sql_with_join(self):
- query = """
- SELECT
- e.employee_id
- , s.store_id
- FROM
- employee e
- INNER JOIN
- store s
- ON
- e.store_id = s.store_id
- """
- df = (
- self.spark.sql(query)
- .groupBy(F.col("store_id"))
- .agg(F.countDistinct(F.col("employee_id")))
- )
- dfs = (
- self.sqlglot.sql(query)
- .groupBy(SF.col("store_id"))
- .agg(SF.countDistinct(SF.col("employee_id")))
- )
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_nameless_column(self):
- query = "SELECT MAX(age) FROM employee"
- df = self.spark.sql(query)
- dfs = self.sqlglot.sql(query)
- # Spark will alias the column to `max(age)` while sqlglot will alias to `_col_0` so their schemas will differ
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
diff --git a/tests/dataframe/unit/__init__.py b/tests/dataframe/unit/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/tests/dataframe/unit/__init__.py
+++ /dev/null
diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py
deleted file mode 100644
index 4363b0d..0000000
--- a/tests/dataframe/unit/dataframe_sql_validator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from sqlglot.dataframe.sql import types
-from sqlglot.dataframe.sql.session import SparkSession
-from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase
-
-
-class DataFrameSQLValidator(DataFrameTestBase):
- def setUp(self) -> None:
- super().setUp()
- self.spark = SparkSession()
- self.employee_schema = types.StructType(
- [
- types.StructField("employee_id", types.IntegerType(), False),
- types.StructField("fname", types.StringType(), False),
- types.StructField("lname", types.StringType(), False),
- types.StructField("age", types.IntegerType(), False),
- types.StructField("store_id", types.IntegerType(), False),
- ]
- )
- employee_data = [
- (1, "Jack", "Shephard", 37, 1),
- (2, "John", "Locke", 65, 1),
- (3, "Kate", "Austen", 37, 2),
- (4, "Claire", "Littleton", 27, 2),
- (5, "Hugo", "Reyes", 29, 100),
- ]
- self.df_employee = self.spark.createDataFrame(
- data=employee_data, schema=self.employee_schema
- )
diff --git a/tests/dataframe/unit/dataframe_test_base.py b/tests/dataframe/unit/dataframe_test_base.py
deleted file mode 100644
index 6b07df9..0000000
--- a/tests/dataframe/unit/dataframe_test_base.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import typing as t
-import unittest
-
-import sqlglot
-from sqlglot import MappingSchema
-from sqlglot.dataframe.sql import SparkSession
-from sqlglot.dataframe.sql.dataframe import DataFrame
-from sqlglot.helper import ensure_list
-
-
-class DataFrameTestBase(unittest.TestCase):
- def setUp(self) -> None:
- sqlglot.schema = MappingSchema()
- SparkSession._instance = None
-
- def compare_sql(
- self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
- ):
- actual_sqls = df.sql(pretty=pretty)
- expected_statements = ensure_list(expected_statements)
- self.assertEqual(len(expected_statements), len(actual_sqls))
- for expected, actual in zip(expected_statements, actual_sqls):
- self.assertEqual(expected, actual)
diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py
deleted file mode 100644
index 833005b..0000000
--- a/tests/dataframe/unit/test_column.py
+++ /dev/null
@@ -1,174 +0,0 @@
-import datetime
-import unittest
-
-from sqlglot.dataframe.sql import functions as F
-from sqlglot.dataframe.sql.window import Window
-
-
-class TestDataframeColumn(unittest.TestCase):
- def test_eq(self):
- self.assertEqual("cola = 1", (F.col("cola") == 1).sql())
-
- def test_neq(self):
- self.assertEqual("cola <> 1", (F.col("cola") != 1).sql())
-
- def test_gt(self):
- self.assertEqual("cola > 1", (F.col("cola") > 1).sql())
-
- def test_lt(self):
- self.assertEqual("cola < 1", (F.col("cola") < 1).sql())
-
- def test_le(self):
- self.assertEqual("cola <= 1", (F.col("cola") <= 1).sql())
-
- def test_ge(self):
- self.assertEqual("cola >= 1", (F.col("cola") >= 1).sql())
-
- def test_and(self):
- self.assertEqual(
- "cola = colb AND colc = cold",
- ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql(),
- )
-
- def test_or(self):
- self.assertEqual(
- "cola = colb OR colc = cold",
- ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql(),
- )
-
- def test_mod(self):
- self.assertEqual("cola % 2", (F.col("cola") % 2).sql())
-
- def test_add(self):
- self.assertEqual("cola + 1", (F.col("cola") + 1).sql())
-
- def test_sub(self):
- self.assertEqual("cola - 1", (F.col("cola") - 1).sql())
-
- def test_mul(self):
- self.assertEqual("cola * 2", (F.col("cola") * 2).sql())
-
- def test_div(self):
- self.assertEqual("cola / 2", (F.col("cola") / 2).sql())
-
- def test_radd(self):
- self.assertEqual("1 + cola", (1 + F.col("cola")).sql())
-
- def test_rsub(self):
- self.assertEqual("1 - cola", (1 - F.col("cola")).sql())
-
- def test_rmul(self):
- self.assertEqual("1 * cola", (1 * F.col("cola")).sql())
-
- def test_rdiv(self):
- self.assertEqual("1 / cola", (1 / F.col("cola")).sql())
-
- def test_pow(self):
- self.assertEqual("POWER(cola, 2)", (F.col("cola") ** 2).sql())
-
- def test_rpow(self):
- self.assertEqual("POWER(2, cola)", (2 ** F.col("cola")).sql())
-
- def test_invert(self):
- self.assertEqual("NOT cola", (~F.col("cola")).sql())
-
- def test_startswith(self):
- self.assertEqual("STARTSWITH(cola, 'test')", F.col("cola").startswith("test").sql())
-
- def test_endswith(self):
- self.assertEqual("ENDSWITH(cola, 'test')", F.col("cola").endswith("test").sql())
-
- def test_rlike(self):
- self.assertEqual("cola RLIKE 'foo'", F.col("cola").rlike("foo").sql())
-
- def test_like(self):
- self.assertEqual("cola LIKE 'foo%'", F.col("cola").like("foo%").sql())
-
- def test_ilike(self):
- self.assertEqual("cola ILIKE 'foo%'", F.col("cola").ilike("foo%").sql())
-
- def test_substring(self):
- self.assertEqual("SUBSTRING(cola, 2, 3)", F.col("cola").substr(2, 3).sql())
-
- def test_isin(self):
- self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin([1, 2, 3]).sql())
- self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin(1, 2, 3).sql())
-
- def test_asc(self):
- self.assertEqual("cola ASC", F.col("cola").asc().sql())
-
- def test_desc(self):
- self.assertEqual("cola DESC", F.col("cola").desc().sql())
-
- def test_asc_nulls_first(self):
- self.assertEqual("cola ASC", F.col("cola").asc_nulls_first().sql())
-
- def test_asc_nulls_last(self):
- self.assertEqual("cola ASC NULLS LAST", F.col("cola").asc_nulls_last().sql())
-
- def test_desc_nulls_first(self):
- self.assertEqual("cola DESC NULLS FIRST", F.col("cola").desc_nulls_first().sql())
-
- def test_desc_nulls_last(self):
- self.assertEqual("cola DESC", F.col("cola").desc_nulls_last().sql())
-
- def test_when_otherwise(self):
- self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
- self.assertEqual(
- "CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()
- )
- self.assertEqual(
- "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
- (F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
- )
- self.assertEqual(
- "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
- F.col("cola").when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).sql(),
- )
- self.assertEqual(
- "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 ELSE 4 END",
- F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).otherwise(4).sql(),
- )
-
- def test_is_null(self):
- self.assertEqual("cola IS NULL", F.col("cola").isNull().sql())
-
- def test_is_not_null(self):
- self.assertEqual("NOT cola IS NULL", F.col("cola").isNotNull().sql())
-
- def test_cast(self):
- self.assertEqual("CAST(cola AS INT)", F.col("cola").cast("INT").sql())
-
- def test_alias(self):
- self.assertEqual("cola AS new_name", F.col("cola").alias("new_name").sql())
-
- def test_between(self):
- self.assertEqual("cola BETWEEN 1 AND 3", F.col("cola").between(1, 3).sql())
- self.assertEqual("cola BETWEEN 10.1 AND 12.1", F.col("cola").between(10.1, 12.1).sql())
- self.assertEqual(
- "cola BETWEEN CAST('2022-01-01' AS DATE) AND CAST('2022-03-01' AS DATE)",
- F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(),
- )
- self.assertEqual(
- "cola BETWEEN CAST('2022-01-01 01:01:01+00:00' AS TIMESTAMP) "
- "AND CAST('2022-03-01 01:01:01+00:00' AS TIMESTAMP)",
- F.col("cola")
- .between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1))
- .sql(),
- )
-
- def test_over(self):
- over_rows = F.sum("cola").over(
- Window.partitionBy("colb").orderBy("colc").rowsBetween(1, Window.unboundedFollowing)
- )
- self.assertEqual(
- "SUM(cola) OVER (PARTITION BY colb ORDER BY colc ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
- over_rows.sql(),
- )
- over_range = F.sum("cola").over(
- Window.partitionBy("colb").orderBy("colc").rangeBetween(1, Window.unboundedFollowing)
- )
- self.assertEqual(
- "SUM(cola) OVER (PARTITION BY colb ORDER BY colc RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
- over_range.sql(),
- )
diff --git a/tests/dataframe/unit/test_dataframe.py b/tests/dataframe/unit/test_dataframe.py
deleted file mode 100644
index 24850bc..0000000
--- a/tests/dataframe/unit/test_dataframe.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from sqlglot import expressions as exp
-from sqlglot.dataframe.sql.dataframe import DataFrame
-from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
-
-
-class TestDataframe(DataFrameSQLValidator):
- maxDiff = None
-
- def test_hash_select_expression(self):
- expression = exp.select("cola").from_("table")
- self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
-
- def test_columns(self):
- self.assertEqual(
- ["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns
- )
-
- def test_cache(self):
- df = self.df_employee.select("fname").cache()
- expected_statements = [
- "DROP VIEW IF EXISTS t31563",
- "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
- ]
- self.compare_sql(df, expected_statements)
-
- def test_persist_default(self):
- df = self.df_employee.select("fname").persist()
- expected_statements = [
- "DROP VIEW IF EXISTS t31563",
- "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
- ]
- self.compare_sql(df, expected_statements)
-
- def test_persist_storagelevel(self):
- df = self.df_employee.select("fname").persist("DISK_ONLY_2")
- expected_statements = [
- "DROP VIEW IF EXISTS t31563",
- "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
- ]
- self.compare_sql(df, expected_statements)
diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py
deleted file mode 100644
index 303d2f9..0000000
--- a/tests/dataframe/unit/test_dataframe_writer.py
+++ /dev/null
@@ -1,95 +0,0 @@
-from unittest import mock
-
-import sqlglot
-from sqlglot.schema import MappingSchema
-from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
-
-
-class TestDataFrameWriter(DataFrameSQLValidator):
- maxDiff = None
-
- def test_insertInto_full_path(self):
- df = self.df_employee.write.insertInto("catalog.db.table_name")
- expected = "INSERT INTO catalog.db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
- self.compare_sql(df, expected)
-
- def test_insertInto_db_table(self):
- df = self.df_employee.write.insertInto("db.table_name")
- expected = "INSERT INTO db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
- self.compare_sql(df, expected)
-
- def test_insertInto_table(self):
- df = self.df_employee.write.insertInto("table_name")
- expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
- self.compare_sql(df, expected)
-
- def test_insertInto_overwrite(self):
- df = self.df_employee.write.insertInto("table_name", overwrite=True)
- expected = "INSERT OVERWRITE TABLE table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
- self.compare_sql(df, expected)
-
- @mock.patch("sqlglot.schema", MappingSchema())
- def test_insertInto_byName(self):
- sqlglot.schema.add_table("table_name", {"employee_id": "INT"}, dialect="spark")
- df = self.df_employee.write.byName.insertInto("table_name")
- expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
- self.compare_sql(df, expected)
-
- def test_insertInto_cache(self):
- df = self.df_employee.cache().write.insertInto("table_name")
- expected_statements = [
- "DROP VIEW IF EXISTS t12441",
- "CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "INSERT INTO table_name SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`",
- ]
- self.compare_sql(df, expected_statements)
-
- def test_saveAsTable_format(self):
- with self.assertRaises(NotImplementedError):
- self.df_employee.write.saveAsTable("table_name", format="parquet").sql(pretty=False)[0]
-
- def test_saveAsTable_append(self):
- df = self.df_employee.write.saveAsTable("table_name", mode="append")
- expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
- self.compare_sql(df, expected)
-
- def test_saveAsTable_overwrite(self):
- df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
- expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
- self.compare_sql(df, expected)
-
- def test_saveAsTable_error(self):
- df = self.df_employee.write.saveAsTable("table_name", mode="error")
- expected = "CREATE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
- self.compare_sql(df, expected)
-
- def test_saveAsTable_ignore(self):
- df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
- expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
- self.compare_sql(df, expected)
-
- def test_mode_standalone(self):
- df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
- expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
- self.compare_sql(df, expected)
-
- def test_mode_override(self):
- df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
- expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
- self.compare_sql(df, expected)
-
- def test_saveAsTable_cache(self):
- df = self.df_employee.cache().write.saveAsTable("table_name")
- expected_statements = [
- "DROP VIEW IF EXISTS t12441",
- "CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "CREATE TABLE table_name AS SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`",
- ]
- self.compare_sql(df, expected_statements)
-
- def test_quotes(self):
- sqlglot.schema.add_table("`Test`", {"`ID`": "STRING"}, dialect="spark")
- df = self.spark.table("`Test`")
- self.compare_sql(
- df.select(df["`ID`"]), ["SELECT `test`.`id` AS `id` FROM `test` AS `test`"]
- )
diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py
deleted file mode 100644
index 884cded..0000000
--- a/tests/dataframe/unit/test_functions.py
+++ /dev/null
@@ -1,1632 +0,0 @@
-import datetime
-import inspect
-import unittest
-
-from sqlglot import expressions as exp, parse_one
-from sqlglot.dataframe.sql import functions as SF
-from sqlglot.errors import ErrorLevel
-
-
-class TestFunctions(unittest.TestCase):
- def test_invoke_anonymous(self):
- for name, func in inspect.getmembers(SF, inspect.isfunction):
- with self.subTest(f"{name} should not invoke anonymous_function"):
- if "invoke_anonymous_function" in inspect.getsource(func):
- func = parse_one(f"{name}()", read="spark", error_level=ErrorLevel.IGNORE)
- self.assertIsInstance(func, exp.Anonymous)
-
- def test_lit(self):
- test_str = SF.lit("test")
- self.assertEqual("'test'", test_str.sql())
- test_int = SF.lit(30)
- self.assertEqual("30", test_int.sql())
- test_float = SF.lit(10.10)
- self.assertEqual("10.1", test_float.sql())
- test_bool = SF.lit(False)
- self.assertEqual("FALSE", test_bool.sql())
- test_null = SF.lit(None)
- self.assertEqual("NULL", test_null.sql())
- test_date = SF.lit(datetime.date(2022, 1, 1))
- self.assertEqual("CAST('2022-01-01' AS DATE)", test_date.sql())
- test_datetime = SF.lit(datetime.datetime(2022, 1, 1, 1, 1, 1))
- self.assertEqual("CAST('2022-01-01 01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql())
- test_dict = SF.lit({"cola": 1, "colb": "test"})
- self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql())
-
- def test_col(self):
- test_col = SF.col("cola")
- self.assertEqual("cola", test_col.sql())
- test_col_with_table = SF.col("table.cola")
- self.assertEqual("table.cola", test_col_with_table.sql())
- test_col_on_col = SF.col(test_col)
- self.assertEqual("cola", test_col_on_col.sql())
- test_int = SF.col(10)
- self.assertEqual("10", test_int.sql())
- test_float = SF.col(10.10)
- self.assertEqual("10.1", test_float.sql())
- test_bool = SF.col(True)
- self.assertEqual("TRUE", test_bool.sql())
- test_array = SF.col([1, 2, "3"])
- self.assertEqual("ARRAY(1, 2, '3')", test_array.sql())
- test_date = SF.col(datetime.date(2022, 1, 1))
- self.assertEqual("CAST('2022-01-01' AS DATE)", test_date.sql())
- test_datetime = SF.col(datetime.datetime(2022, 1, 1, 1, 1, 1))
- self.assertEqual("CAST('2022-01-01 01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql())
- test_dict = SF.col({"cola": 1, "colb": "test"})
- self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql())
-
- def test_asc(self):
- asc_str = SF.asc("cola")
- # ASC is removed from output since that is default so we can't check sql
- self.assertIsInstance(asc_str.expression, exp.Ordered)
- asc_col = SF.asc(SF.col("cola"))
- self.assertIsInstance(asc_col.expression, exp.Ordered)
-
- def test_desc(self):
- desc_str = SF.desc("cola")
- self.assertEqual("cola DESC", desc_str.sql())
- desc_col = SF.desc(SF.col("cola"))
- self.assertEqual("cola DESC", desc_col.sql())
-
- def test_sqrt(self):
- col_str = SF.sqrt("cola")
- self.assertEqual("SQRT(cola)", col_str.sql())
- col = SF.sqrt(SF.col("cola"))
- self.assertEqual("SQRT(cola)", col.sql())
-
- def test_abs(self):
- col_str = SF.abs("cola")
- self.assertEqual("ABS(cola)", col_str.sql())
- col = SF.abs(SF.col("cola"))
- self.assertEqual("ABS(cola)", col.sql())
-
- def test_max(self):
- col_str = SF.max("cola")
- self.assertEqual("MAX(cola)", col_str.sql())
- col = SF.max(SF.col("cola"))
- self.assertEqual("MAX(cola)", col.sql())
-
- def test_min(self):
- col_str = SF.min("cola")
- self.assertEqual("MIN(cola)", col_str.sql())
- col = SF.min(SF.col("cola"))
- self.assertEqual("MIN(cola)", col.sql())
-
- def test_max_by(self):
- col_str = SF.max_by("cola", "colb")
- self.assertEqual("MAX_BY(cola, colb)", col_str.sql())
- col = SF.max_by(SF.col("cola"), SF.col("colb"))
- self.assertEqual("MAX_BY(cola, colb)", col.sql())
-
- def test_min_by(self):
- col_str = SF.min_by("cola", "colb")
- self.assertEqual("MIN_BY(cola, colb)", col_str.sql())
- col = SF.min_by(SF.col("cola"), SF.col("colb"))
- self.assertEqual("MIN_BY(cola, colb)", col.sql())
-
- def test_count(self):
- col_str = SF.count("cola")
- self.assertEqual("COUNT(cola)", col_str.sql())
- col = SF.count(SF.col("cola"))
- self.assertEqual("COUNT(cola)", col.sql())
-
- def test_sum(self):
- col_str = SF.sum("cola")
- self.assertEqual("SUM(cola)", col_str.sql())
- col = SF.sum(SF.col("cola"))
- self.assertEqual("SUM(cola)", col.sql())
-
- def test_avg(self):
- col_str = SF.avg("cola")
- self.assertEqual("AVG(cola)", col_str.sql())
- col = SF.avg(SF.col("cola"))
- self.assertEqual("AVG(cola)", col.sql())
-
- def test_mean(self):
- col_str = SF.mean("cola")
- self.assertEqual("MEAN(cola)", col_str.sql())
- col = SF.mean(SF.col("cola"))
- self.assertEqual("MEAN(cola)", col.sql())
-
- def test_sum_distinct(self):
- with self.assertRaises(NotImplementedError):
- SF.sum_distinct("cola")
- with self.assertRaises(NotImplementedError):
- SF.sumDistinct("cola")
-
- def test_product(self):
- with self.assertRaises(NotImplementedError):
- SF.product("cola")
- with self.assertRaises(NotImplementedError):
- SF.product("cola")
-
- def test_acos(self):
- col_str = SF.acos("cola")
- self.assertEqual("ACOS(cola)", col_str.sql())
- col = SF.acos(SF.col("cola"))
- self.assertEqual("ACOS(cola)", col.sql())
-
- def test_acosh(self):
- col_str = SF.acosh("cola")
- self.assertEqual("ACOSH(cola)", col_str.sql())
- col = SF.acosh(SF.col("cola"))
- self.assertEqual("ACOSH(cola)", col.sql())
-
- def test_asin(self):
- col_str = SF.asin("cola")
- self.assertEqual("ASIN(cola)", col_str.sql())
- col = SF.asin(SF.col("cola"))
- self.assertEqual("ASIN(cola)", col.sql())
-
- def test_asinh(self):
- col_str = SF.asinh("cola")
- self.assertEqual("ASINH(cola)", col_str.sql())
- col = SF.asinh(SF.col("cola"))
- self.assertEqual("ASINH(cola)", col.sql())
-
- def test_atan(self):
- col_str = SF.atan("cola")
- self.assertEqual("ATAN(cola)", col_str.sql())
- col = SF.atan(SF.col("cola"))
- self.assertEqual("ATAN(cola)", col.sql())
-
- def test_atan2(self):
- col_str = SF.atan2("cola", "colb")
- self.assertEqual("ATAN2(cola, colb)", col_str.sql())
- col = SF.atan2(SF.col("cola"), SF.col("colb"))
- self.assertEqual("ATAN2(cola, colb)", col.sql())
- col_float = SF.atan2(10.10, "colb")
- self.assertEqual("ATAN2(10.1, colb)", col_float.sql())
- col_float2 = SF.atan2("cola", 10.10)
- self.assertEqual("ATAN2(cola, 10.1)", col_float2.sql())
-
- def test_atanh(self):
- col_str = SF.atanh("cola")
- self.assertEqual("ATANH(cola)", col_str.sql())
- col = SF.atanh(SF.col("cola"))
- self.assertEqual("ATANH(cola)", col.sql())
-
- def test_cbrt(self):
- col_str = SF.cbrt("cola")
- self.assertEqual("CBRT(cola)", col_str.sql())
- col = SF.cbrt(SF.col("cola"))
- self.assertEqual("CBRT(cola)", col.sql())
-
- def test_ceil(self):
- col_str = SF.ceil("cola")
- self.assertEqual("CEIL(cola)", col_str.sql())
- col = SF.ceil(SF.col("cola"))
- self.assertEqual("CEIL(cola)", col.sql())
-
- def test_cos(self):
- col_str = SF.cos("cola")
- self.assertEqual("COS(cola)", col_str.sql())
- col = SF.cos(SF.col("cola"))
- self.assertEqual("COS(cola)", col.sql())
-
- def test_cosh(self):
- col_str = SF.cosh("cola")
- self.assertEqual("COSH(cola)", col_str.sql())
- col = SF.cosh(SF.col("cola"))
- self.assertEqual("COSH(cola)", col.sql())
-
- def test_cot(self):
- col_str = SF.cot("cola")
- self.assertEqual("COT(cola)", col_str.sql())
- col = SF.cot(SF.col("cola"))
- self.assertEqual("COT(cola)", col.sql())
-
- def test_csc(self):
- col_str = SF.csc("cola")
- self.assertEqual("CSC(cola)", col_str.sql())
- col = SF.csc(SF.col("cola"))
- self.assertEqual("CSC(cola)", col.sql())
-
- def test_exp(self):
- col_str = SF.exp("cola")
- self.assertEqual("EXP(cola)", col_str.sql())
- col = SF.exp(SF.col("cola"))
- self.assertEqual("EXP(cola)", col.sql())
-
- def test_expm1(self):
- col_str = SF.expm1("cola")
- self.assertEqual("EXPM1(cola)", col_str.sql())
- col = SF.expm1(SF.col("cola"))
- self.assertEqual("EXPM1(cola)", col.sql())
-
- def test_floor(self):
- col_str = SF.floor("cola")
- self.assertEqual("FLOOR(cola)", col_str.sql())
- col = SF.floor(SF.col("cola"))
- self.assertEqual("FLOOR(cola)", col.sql())
-
- def test_log(self):
- col_str = SF.log("cola")
- self.assertEqual("LN(cola)", col_str.sql())
- col = SF.log(SF.col("cola"))
- self.assertEqual("LN(cola)", col.sql())
- col_arg = SF.log(10.0, "age")
- self.assertEqual("LOG(10.0, age)", col_arg.sql())
-
- def test_log10(self):
- col_str = SF.log10("cola")
- self.assertEqual("LOG(10, cola)", col_str.sql())
- col = SF.log10(SF.col("cola"))
- self.assertEqual("LOG(10, cola)", col.sql())
-
- def test_log1p(self):
- col_str = SF.log1p("cola")
- self.assertEqual("LOG1P(cola)", col_str.sql())
- col = SF.log1p(SF.col("cola"))
- self.assertEqual("LOG1P(cola)", col.sql())
-
- def test_log2(self):
- col_str = SF.log2("cola")
- self.assertEqual("LOG(2, cola)", col_str.sql())
- col = SF.log2(SF.col("cola"))
- self.assertEqual("LOG(2, cola)", col.sql())
-
- def test_rint(self):
- col_str = SF.rint("cola")
- self.assertEqual("RINT(cola)", col_str.sql())
- col = SF.rint(SF.col("cola"))
- self.assertEqual("RINT(cola)", col.sql())
-
- def test_sec(self):
- col_str = SF.sec("cola")
- self.assertEqual("SEC(cola)", col_str.sql())
- col = SF.sec(SF.col("cola"))
- self.assertEqual("SEC(cola)", col.sql())
-
- def test_signum(self):
- col_str = SF.signum("cola")
- self.assertEqual("SIGN(cola)", col_str.sql())
- col = SF.signum(SF.col("cola"))
- self.assertEqual("SIGN(cola)", col.sql())
-
- def test_sin(self):
- col_str = SF.sin("cola")
- self.assertEqual("SIN(cola)", col_str.sql())
- col = SF.sin(SF.col("cola"))
- self.assertEqual("SIN(cola)", col.sql())
-
- def test_sinh(self):
- col_str = SF.sinh("cola")
- self.assertEqual("SINH(cola)", col_str.sql())
- col = SF.sinh(SF.col("cola"))
- self.assertEqual("SINH(cola)", col.sql())
-
- def test_tan(self):
- col_str = SF.tan("cola")
- self.assertEqual("TAN(cola)", col_str.sql())
- col = SF.tan(SF.col("cola"))
- self.assertEqual("TAN(cola)", col.sql())
-
- def test_tanh(self):
- col_str = SF.tanh("cola")
- self.assertEqual("TANH(cola)", col_str.sql())
- col = SF.tanh(SF.col("cola"))
- self.assertEqual("TANH(cola)", col.sql())
-
- def test_degrees(self):
- col_str = SF.degrees("cola")
- self.assertEqual("DEGREES(cola)", col_str.sql())
- col = SF.degrees(SF.col("cola"))
- self.assertEqual("DEGREES(cola)", col.sql())
- col_legacy = SF.toDegrees(SF.col("cola"))
- self.assertEqual("DEGREES(cola)", col_legacy.sql())
-
- def test_radians(self):
- col_str = SF.radians("cola")
- self.assertEqual("RADIANS(cola)", col_str.sql())
- col = SF.radians(SF.col("cola"))
- self.assertEqual("RADIANS(cola)", col.sql())
- col_legacy = SF.toRadians(SF.col("cola"))
- self.assertEqual("RADIANS(cola)", col_legacy.sql())
-
- def test_bitwise_not(self):
- col_str = SF.bitwise_not("cola")
- self.assertEqual("~cola", col_str.sql())
- col = SF.bitwise_not(SF.col("cola"))
- self.assertEqual("~cola", col.sql())
- col_legacy = SF.bitwiseNOT(SF.col("cola"))
- self.assertEqual("~cola", col_legacy.sql())
-
- def test_asc_nulls_first(self):
- col_str = SF.asc_nulls_first("cola")
- self.assertIsInstance(col_str.expression, exp.Ordered)
- self.assertEqual("cola ASC", col_str.sql())
- col = SF.asc_nulls_first(SF.col("cola"))
- self.assertIsInstance(col.expression, exp.Ordered)
- self.assertEqual("cola ASC", col.sql())
-
- def test_asc_nulls_last(self):
- col_str = SF.asc_nulls_last("cola")
- self.assertIsInstance(col_str.expression, exp.Ordered)
- self.assertEqual("cola ASC NULLS LAST", col_str.sql())
- col = SF.asc_nulls_last(SF.col("cola"))
- self.assertIsInstance(col.expression, exp.Ordered)
- self.assertEqual("cola ASC NULLS LAST", col.sql())
-
- def test_desc_nulls_first(self):
- col_str = SF.desc_nulls_first("cola")
- self.assertIsInstance(col_str.expression, exp.Ordered)
- self.assertEqual("cola DESC NULLS FIRST", col_str.sql())
- col = SF.desc_nulls_first(SF.col("cola"))
- self.assertIsInstance(col.expression, exp.Ordered)
- self.assertEqual("cola DESC NULLS FIRST", col.sql())
-
- def test_desc_nulls_last(self):
- col_str = SF.desc_nulls_last("cola")
- self.assertIsInstance(col_str.expression, exp.Ordered)
- self.assertEqual("cola DESC", col_str.sql())
- col = SF.desc_nulls_last(SF.col("cola"))
- self.assertIsInstance(col.expression, exp.Ordered)
- self.assertEqual("cola DESC", col.sql())
-
- def test_stddev(self):
- col_str = SF.stddev("cola")
- self.assertEqual("STDDEV(cola)", col_str.sql())
- col = SF.stddev(SF.col("cola"))
- self.assertEqual("STDDEV(cola)", col.sql())
-
- def test_stddev_samp(self):
- col_str = SF.stddev_samp("cola")
- self.assertEqual("STDDEV_SAMP(cola)", col_str.sql())
- col = SF.stddev_samp(SF.col("cola"))
- self.assertEqual("STDDEV_SAMP(cola)", col.sql())
-
- def test_stddev_pop(self):
- col_str = SF.stddev_pop("cola")
- self.assertEqual("STDDEV_POP(cola)", col_str.sql())
- col = SF.stddev_pop(SF.col("cola"))
- self.assertEqual("STDDEV_POP(cola)", col.sql())
-
- def test_variance(self):
- col_str = SF.variance("cola")
- self.assertEqual("VARIANCE(cola)", col_str.sql())
- col = SF.variance(SF.col("cola"))
- self.assertEqual("VARIANCE(cola)", col.sql())
-
- def test_var_samp(self):
- col_str = SF.var_samp("cola")
- self.assertEqual("VARIANCE(cola)", col_str.sql())
- col = SF.var_samp(SF.col("cola"))
- self.assertEqual("VARIANCE(cola)", col.sql())
-
- def test_var_pop(self):
- col_str = SF.var_pop("cola")
- self.assertEqual("VAR_POP(cola)", col_str.sql())
- col = SF.var_pop(SF.col("cola"))
- self.assertEqual("VAR_POP(cola)", col.sql())
-
- def test_skewness(self):
- col_str = SF.skewness("cola")
- self.assertEqual("SKEWNESS(cola)", col_str.sql())
- col = SF.skewness(SF.col("cola"))
- self.assertEqual("SKEWNESS(cola)", col.sql())
-
- def test_kurtosis(self):
- col_str = SF.kurtosis("cola")
- self.assertEqual("KURTOSIS(cola)", col_str.sql())
- col = SF.kurtosis(SF.col("cola"))
- self.assertEqual("KURTOSIS(cola)", col.sql())
-
- def test_collect_list(self):
- col_str = SF.collect_list("cola")
- self.assertEqual("COLLECT_LIST(cola)", col_str.sql())
- col = SF.collect_list(SF.col("cola"))
- self.assertEqual("COLLECT_LIST(cola)", col.sql())
-
- def test_collect_set(self):
- col_str = SF.collect_set("cola")
- self.assertEqual("COLLECT_SET(cola)", col_str.sql())
- col = SF.collect_set(SF.col("cola"))
- self.assertEqual("COLLECT_SET(cola)", col.sql())
-
- def test_hypot(self):
- col_str = SF.hypot("cola", "colb")
- self.assertEqual("HYPOT(cola, colb)", col_str.sql())
- col = SF.hypot(SF.col("cola"), SF.col("colb"))
- self.assertEqual("HYPOT(cola, colb)", col.sql())
- col_float = SF.hypot(10.10, "colb")
- self.assertEqual("HYPOT(10.1, colb)", col_float.sql())
- col_float2 = SF.hypot("cola", 10.10)
- self.assertEqual("HYPOT(cola, 10.1)", col_float2.sql())
-
- def test_pow(self):
- col_str = SF.pow("cola", "colb")
- self.assertEqual("POWER(cola, colb)", col_str.sql())
- col = SF.pow(SF.col("cola"), SF.col("colb"))
- self.assertEqual("POWER(cola, colb)", col.sql())
- col_float = SF.pow(10.10, "colb")
- self.assertEqual("POWER(10.1, colb)", col_float.sql())
- col_float2 = SF.pow("cola", 10.10)
- self.assertEqual("POWER(cola, 10.1)", col_float2.sql())
-
- def test_row_number(self):
- col_str = SF.row_number()
- self.assertEqual("ROW_NUMBER()", col_str.sql())
- col = SF.row_number()
- self.assertEqual("ROW_NUMBER()", col.sql())
-
- def test_dense_rank(self):
- col_str = SF.dense_rank()
- self.assertEqual("DENSE_RANK()", col_str.sql())
- col = SF.dense_rank()
- self.assertEqual("DENSE_RANK()", col.sql())
-
- def test_rank(self):
- col_str = SF.rank()
- self.assertEqual("RANK()", col_str.sql())
- col = SF.rank()
- self.assertEqual("RANK()", col.sql())
-
- def test_cume_dist(self):
- col_str = SF.cume_dist()
- self.assertEqual("CUME_DIST()", col_str.sql())
- col = SF.cume_dist()
- self.assertEqual("CUME_DIST()", col.sql())
-
- def test_percent_rank(self):
- col_str = SF.percent_rank()
- self.assertEqual("PERCENT_RANK()", col_str.sql())
- col = SF.percent_rank()
- self.assertEqual("PERCENT_RANK()", col.sql())
-
- def test_approx_count_distinct(self):
- col_str = SF.approx_count_distinct("cola")
- self.assertEqual("APPROX_COUNT_DISTINCT(cola)", col_str.sql())
- col_str_with_accuracy = SF.approx_count_distinct("cola", 0.05)
- self.assertEqual("APPROX_COUNT_DISTINCT(cola, 0.05)", col_str_with_accuracy.sql())
- col = SF.approx_count_distinct(SF.col("cola"))
- self.assertEqual("APPROX_COUNT_DISTINCT(cola)", col.sql())
- col_with_accuracy = SF.approx_count_distinct(SF.col("cola"), 0.05)
- self.assertEqual("APPROX_COUNT_DISTINCT(cola, 0.05)", col_with_accuracy.sql())
- col_legacy = SF.approxCountDistinct(SF.col("cola"))
- self.assertEqual("APPROX_COUNT_DISTINCT(cola)", col_legacy.sql())
-
- def test_coalesce(self):
- col_str = SF.coalesce("cola", "colb", "colc")
- self.assertEqual("COALESCE(cola, colb, colc)", col_str.sql())
- col = SF.coalesce(SF.col("cola"), "colb", SF.col("colc"))
- self.assertEqual("COALESCE(cola, colb, colc)", col.sql())
- col_single = SF.coalesce("cola")
- self.assertEqual("COALESCE(cola)", col_single.sql())
-
- def test_corr(self):
- col_str = SF.corr("cola", "colb")
- self.assertEqual("CORR(cola, colb)", col_str.sql())
- col = SF.corr(SF.col("cola"), "colb")
- self.assertEqual("CORR(cola, colb)", col.sql())
-
- def test_covar_pop(self):
- col_str = SF.covar_pop("cola", "colb")
- self.assertEqual("COVAR_POP(cola, colb)", col_str.sql())
- col = SF.covar_pop(SF.col("cola"), "colb")
- self.assertEqual("COVAR_POP(cola, colb)", col.sql())
-
- def test_covar_samp(self):
- col_str = SF.covar_samp("cola", "colb")
- self.assertEqual("COVAR_SAMP(cola, colb)", col_str.sql())
- col = SF.covar_samp(SF.col("cola"), "colb")
- self.assertEqual("COVAR_SAMP(cola, colb)", col.sql())
-
- def test_count_distinct(self):
- col_str = SF.count_distinct("cola")
- self.assertEqual("COUNT(DISTINCT cola)", col_str.sql())
- col = SF.count_distinct(SF.col("cola"))
- self.assertEqual("COUNT(DISTINCT cola)", col.sql())
- col_legacy = SF.countDistinct(SF.col("cola"))
- self.assertEqual("COUNT(DISTINCT cola)", col_legacy.sql())
- col_multiple = SF.count_distinct(SF.col("cola"), SF.col("colb"))
- self.assertEqual("COUNT(DISTINCT cola, colb)", col_multiple.sql())
-
- def test_first(self):
- col_str = SF.first("cola")
- self.assertEqual("FIRST(cola)", col_str.sql())
- col = SF.first(SF.col("cola"))
- self.assertEqual("FIRST(cola)", col.sql())
- ignore_nulls = SF.first("cola", True)
- self.assertEqual("FIRST(cola) IGNORE NULLS", ignore_nulls.sql())
-
- def test_grouping_id(self):
- col_str = SF.grouping_id("cola", "colb")
- self.assertEqual("GROUPING_ID(cola, colb)", col_str.sql())
- col = SF.grouping_id(SF.col("cola"), SF.col("colb"))
- self.assertEqual("GROUPING_ID(cola, colb)", col.sql())
- col_grouping_no_arg = SF.grouping_id()
- self.assertEqual("GROUPING_ID()", col_grouping_no_arg.sql())
- col_grouping_single_arg = SF.grouping_id("cola")
- self.assertEqual("GROUPING_ID(cola)", col_grouping_single_arg.sql())
-
- def test_input_file_name(self):
- col = SF.input_file_name()
- self.assertEqual("INPUT_FILE_NAME()", col.sql())
-
- def test_isnan(self):
- col_str = SF.isnan("cola")
- self.assertEqual("ISNAN(cola)", col_str.sql())
- col = SF.isnan(SF.col("cola"))
- self.assertEqual("ISNAN(cola)", col.sql())
-
- def test_isnull(self):
- col_str = SF.isnull("cola")
- self.assertEqual("ISNULL(cola)", col_str.sql())
- col = SF.isnull(SF.col("cola"))
- self.assertEqual("ISNULL(cola)", col.sql())
-
- def test_last(self):
- col_str = SF.last("cola")
- self.assertEqual("LAST(cola)", col_str.sql())
- col = SF.last(SF.col("cola"))
- self.assertEqual("LAST(cola)", col.sql())
- ignore_nulls = SF.last("cola", True)
- self.assertEqual("LAST(cola) IGNORE NULLS", ignore_nulls.sql())
-
- def test_monotonically_increasing_id(self):
- col = SF.monotonically_increasing_id()
- self.assertEqual("MONOTONICALLY_INCREASING_ID()", col.sql())
-
- def test_nanvl(self):
- col_str = SF.nanvl("cola", "colb")
- self.assertEqual("NANVL(cola, colb)", col_str.sql())
- col = SF.nanvl(SF.col("cola"), SF.col("colb"))
- self.assertEqual("NANVL(cola, colb)", col.sql())
-
- def test_percentile_approx(self):
- col_str = SF.percentile_approx("cola", [0.5, 0.4, 0.1])
- self.assertEqual("PERCENTILE_APPROX(cola, ARRAY(0.5, 0.4, 0.1))", col_str.sql())
- col = SF.percentile_approx(SF.col("cola"), [0.5, 0.4, 0.1])
- self.assertEqual("PERCENTILE_APPROX(cola, ARRAY(0.5, 0.4, 0.1))", col.sql())
- col_accuracy = SF.percentile_approx("cola", 0.1, 100)
- self.assertEqual("PERCENTILE_APPROX(cola, 0.1, 100)", col_accuracy.sql())
-
- def test_rand(self):
- col_str = SF.rand(SF.lit(0))
- self.assertEqual("RAND(0)", col_str.sql())
- col = SF.rand(SF.lit(0))
- self.assertEqual("RAND(0)", col.sql())
- no_col = SF.rand()
- self.assertEqual("RAND()", no_col.sql())
-
- def test_randn(self):
- col_str = SF.randn(0)
- self.assertEqual("RANDN(0)", col_str.sql())
- col = SF.randn(0)
- self.assertEqual("RANDN(0)", col.sql())
- no_col = SF.randn()
- self.assertEqual("RANDN()", no_col.sql())
-
- def test_round(self):
- col_str = SF.round("cola", 0)
- self.assertEqual("ROUND(cola, 0)", col_str.sql())
- col = SF.round(SF.col("cola"), 0)
- self.assertEqual("ROUND(cola, 0)", col.sql())
- col_no_scale = SF.round("cola")
- self.assertEqual("ROUND(cola)", col_no_scale.sql())
-
- def test_bround(self):
- col_str = SF.bround("cola", 0)
- self.assertEqual("BROUND(cola, 0)", col_str.sql())
- col = SF.bround(SF.col("cola"), 0)
- self.assertEqual("BROUND(cola, 0)", col.sql())
- col_no_scale = SF.bround("cola")
- self.assertEqual("BROUND(cola)", col_no_scale.sql())
-
- def test_shiftleft(self):
- col_str = SF.shiftleft("cola", 1)
- self.assertEqual("SHIFTLEFT(cola, 1)", col_str.sql())
- col = SF.shiftleft(SF.col("cola"), 1)
- self.assertEqual("SHIFTLEFT(cola, 1)", col.sql())
- col_legacy = SF.shiftLeft(SF.col("cola"), 1)
- self.assertEqual("SHIFTLEFT(cola, 1)", col_legacy.sql())
-
- def test_shiftright(self):
- col_str = SF.shiftright("cola", 1)
- self.assertEqual("SHIFTRIGHT(cola, 1)", col_str.sql())
- col = SF.shiftright(SF.col("cola"), 1)
- self.assertEqual("SHIFTRIGHT(cola, 1)", col.sql())
- col_legacy = SF.shiftRight(SF.col("cola"), 1)
- self.assertEqual("SHIFTRIGHT(cola, 1)", col_legacy.sql())
-
- def test_shiftrightunsigned(self):
- col_str = SF.shiftrightunsigned("cola", 1)
- self.assertEqual("SHIFTRIGHTUNSIGNED(cola, 1)", col_str.sql())
- col = SF.shiftrightunsigned(SF.col("cola"), 1)
- self.assertEqual("SHIFTRIGHTUNSIGNED(cola, 1)", col.sql())
- col_legacy = SF.shiftRightUnsigned(SF.col("cola"), 1)
- self.assertEqual("SHIFTRIGHTUNSIGNED(cola, 1)", col_legacy.sql())
-
- def test_expr(self):
- col_str = SF.expr("LENGTH(name)")
- self.assertEqual("LENGTH(name)", col_str.sql())
-
- def test_struct(self):
- col_str = SF.struct("cola", "colb", "colc")
- self.assertEqual("STRUCT(cola, colb, colc)", col_str.sql())
- col = SF.struct(SF.col("cola"), SF.col("colb"), SF.col("colc"))
- self.assertEqual("STRUCT(cola, colb, colc)", col.sql())
- col_single = SF.struct("cola")
- self.assertEqual("STRUCT(cola)", col_single.sql())
- col_list = SF.struct(["cola", "colb", "colc"])
- self.assertEqual("STRUCT(cola, colb, colc)", col_list.sql())
-
- def test_greatest(self):
- single_str = SF.greatest("cola")
- self.assertEqual("GREATEST(cola)", single_str.sql())
- single_col = SF.greatest(SF.col("cola"))
- self.assertEqual("GREATEST(cola)", single_col.sql())
- multiple_mix = SF.greatest("col1", "col2", SF.col("col3"), SF.col("col4"))
- self.assertEqual("GREATEST(col1, col2, col3, col4)", multiple_mix.sql())
-
- def test_least(self):
- single_str = SF.least("cola")
- self.assertEqual("LEAST(cola)", single_str.sql())
- single_col = SF.least(SF.col("cola"))
- self.assertEqual("LEAST(cola)", single_col.sql())
- multiple_mix = SF.least("col1", "col2", SF.col("col3"), SF.col("col4"))
- self.assertEqual("LEAST(col1, col2, col3, col4)", multiple_mix.sql())
-
- def test_when(self):
- col_simple = SF.when(SF.col("cola") == 2, 1)
- self.assertEqual("CASE WHEN cola = 2 THEN 1 END", col_simple.sql())
- col_complex = SF.when(SF.col("cola") == 2, SF.col("colb") + 2)
- self.assertEqual("CASE WHEN cola = 2 THEN colb + 2 END", col_complex.sql())
-
- def test_conv(self):
- col_str = SF.conv("cola", 2, 16)
- self.assertEqual("CONV(cola, 2, 16)", col_str.sql())
- col = SF.conv(SF.col("cola"), 2, 16)
- self.assertEqual("CONV(cola, 2, 16)", col.sql())
-
- def test_factorial(self):
- col_str = SF.factorial("cola")
- self.assertEqual("FACTORIAL(cola)", col_str.sql())
- col = SF.factorial(SF.col("cola"))
- self.assertEqual("FACTORIAL(cola)", col.sql())
-
- def test_lag(self):
- col_str = SF.lag("cola", 3, "colc")
- self.assertEqual("LAG(cola, 3, colc)", col_str.sql())
- col = SF.lag(SF.col("cola"), 3, "colc")
- self.assertEqual("LAG(cola, 3, colc)", col.sql())
- col_no_default = SF.lag("cola", 3)
- self.assertEqual("LAG(cola, 3)", col_no_default.sql())
- col_no_offset = SF.lag("cola")
- self.assertEqual("LAG(cola)", col_no_offset.sql())
-
- def test_lead(self):
- col_str = SF.lead("cola", 3, "colc")
- self.assertEqual("LEAD(cola, 3, colc)", col_str.sql())
- col = SF.lead(SF.col("cola"), 3, "colc")
- self.assertEqual("LEAD(cola, 3, colc)", col.sql())
- col_no_default = SF.lead("cola", 3)
- self.assertEqual("LEAD(cola, 3)", col_no_default.sql())
- col_no_offset = SF.lead("cola")
- self.assertEqual("LEAD(cola)", col_no_offset.sql())
-
- def test_nth_value(self):
- col_str = SF.nth_value("cola", 3)
- self.assertEqual("NTH_VALUE(cola, 3)", col_str.sql())
- col = SF.nth_value(SF.col("cola"), 3)
- self.assertEqual("NTH_VALUE(cola, 3)", col.sql())
- col_no_offset = SF.nth_value("cola")
- self.assertEqual("NTH_VALUE(cola)", col_no_offset.sql())
-
- self.assertEqual(
- "NTH_VALUE(cola) IGNORE NULLS", SF.nth_value("cola", ignoreNulls=True).sql()
- )
-
- def test_ntile(self):
- col = SF.ntile(2)
- self.assertEqual("NTILE(2)", col.sql())
-
- def test_current_date(self):
- col = SF.current_date()
- self.assertEqual("CURRENT_DATE", col.sql())
-
- def test_current_timestamp(self):
- col = SF.current_timestamp()
- self.assertEqual("CURRENT_TIMESTAMP()", col.sql())
-
- def test_date_format(self):
- col_str = SF.date_format("cola", "MM/dd/yyy")
- self.assertEqual("DATE_FORMAT(cola, 'MM/dd/yyy')", col_str.sql())
- col = SF.date_format(SF.col("cola"), "MM/dd/yyy")
- self.assertEqual("DATE_FORMAT(cola, 'MM/dd/yyy')", col.sql())
-
- def test_year(self):
- col_str = SF.year("cola")
- self.assertEqual("YEAR(cola)", col_str.sql())
- col = SF.year(SF.col("cola"))
- self.assertEqual("YEAR(cola)", col.sql())
-
- def test_quarter(self):
- col_str = SF.quarter("cola")
- self.assertEqual("QUARTER(cola)", col_str.sql())
- col = SF.quarter(SF.col("cola"))
- self.assertEqual("QUARTER(cola)", col.sql())
-
- def test_month(self):
- col_str = SF.month("cola")
- self.assertEqual("MONTH(cola)", col_str.sql())
- col = SF.month(SF.col("cola"))
- self.assertEqual("MONTH(cola)", col.sql())
-
- def test_dayofweek(self):
- col_str = SF.dayofweek("cola")
- self.assertEqual("DAYOFWEEK(cola)", col_str.sql())
- col = SF.dayofweek(SF.col("cola"))
- self.assertEqual("DAYOFWEEK(cola)", col.sql())
-
- def test_dayofmonth(self):
- col_str = SF.dayofmonth("cola")
- self.assertEqual("DAYOFMONTH(cola)", col_str.sql())
- col = SF.dayofmonth(SF.col("cola"))
- self.assertEqual("DAYOFMONTH(cola)", col.sql())
-
- def test_dayofyear(self):
- col_str = SF.dayofyear("cola")
- self.assertEqual("DAYOFYEAR(cola)", col_str.sql())
- col = SF.dayofyear(SF.col("cola"))
- self.assertEqual("DAYOFYEAR(cola)", col.sql())
-
- def test_hour(self):
- col_str = SF.hour("cola")
- self.assertEqual("HOUR(cola)", col_str.sql())
- col = SF.hour(SF.col("cola"))
- self.assertEqual("HOUR(cola)", col.sql())
-
- def test_minute(self):
- col_str = SF.minute("cola")
- self.assertEqual("MINUTE(cola)", col_str.sql())
- col = SF.minute(SF.col("cola"))
- self.assertEqual("MINUTE(cola)", col.sql())
-
- def test_second(self):
- col_str = SF.second("cola")
- self.assertEqual("SECOND(cola)", col_str.sql())
- col = SF.second(SF.col("cola"))
- self.assertEqual("SECOND(cola)", col.sql())
-
- def test_weekofyear(self):
- col_str = SF.weekofyear("cola")
- self.assertEqual("WEEKOFYEAR(cola)", col_str.sql())
- col = SF.weekofyear(SF.col("cola"))
- self.assertEqual("WEEKOFYEAR(cola)", col.sql())
-
- def test_make_date(self):
- col_str = SF.make_date("cola", "colb", "colc")
- self.assertEqual("MAKE_DATE(cola, colb, colc)", col_str.sql())
- col = SF.make_date(SF.col("cola"), SF.col("colb"), "colc")
- self.assertEqual("MAKE_DATE(cola, colb, colc)", col.sql())
-
- def test_date_add(self):
- col_str = SF.date_add("cola", 2)
- self.assertEqual("DATE_ADD(cola, 2)", col_str.sql())
- col = SF.date_add(SF.col("cola"), 2)
- self.assertEqual("DATE_ADD(cola, 2)", col.sql())
- col_col_for_add = SF.date_add("cola", "colb")
- self.assertEqual("DATE_ADD(cola, colb)", col_col_for_add.sql())
- current_date_add = SF.date_add(SF.current_date(), 5)
- self.assertEqual("DATE_ADD(CURRENT_DATE, 5)", current_date_add.sql())
- self.assertEqual("DATEADD(DAY, 5, CURRENT_DATE)", current_date_add.sql(dialect="snowflake"))
-
- def test_date_sub(self):
- col_str = SF.date_sub("cola", 2)
- self.assertEqual("DATE_ADD(cola, -2)", col_str.sql())
- col = SF.date_sub(SF.col("cola"), 2)
- self.assertEqual("DATE_ADD(cola, -2)", col.sql())
- col_col_for_add = SF.date_sub("cola", "colb")
- self.assertEqual("DATE_ADD(cola, colb * -1)", col_col_for_add.sql())
-
- def test_date_diff(self):
- col_str = SF.date_diff("cola", "colb")
- self.assertEqual("DATEDIFF(cola, colb)", col_str.sql())
- col = SF.date_diff(SF.col("cola"), SF.col("colb"))
- self.assertEqual("DATEDIFF(cola, colb)", col.sql())
-
- def test_add_months(self):
- col_str = SF.add_months("cola", 2)
- self.assertEqual("ADD_MONTHS(cola, 2)", col_str.sql())
- col = SF.add_months(SF.col("cola"), 2)
- self.assertEqual("ADD_MONTHS(cola, 2)", col.sql())
- col_col_for_add = SF.add_months("cola", "colb")
- self.assertEqual("ADD_MONTHS(cola, colb)", col_col_for_add.sql())
-
- def test_months_between(self):
- col_str = SF.months_between("cola", "colb")
- self.assertEqual("MONTHS_BETWEEN(cola, colb)", col_str.sql())
- col = SF.months_between(SF.col("cola"), SF.col("colb"))
- self.assertEqual("MONTHS_BETWEEN(cola, colb)", col.sql())
- col_round_off = SF.months_between("cola", "colb", True)
- self.assertEqual("MONTHS_BETWEEN(cola, colb, TRUE)", col_round_off.sql())
-
- def test_to_date(self):
- col_str = SF.to_date("cola")
- self.assertEqual("TO_DATE(cola)", col_str.sql())
- col = SF.to_date(SF.col("cola"))
- self.assertEqual("TO_DATE(cola)", col.sql())
- col_with_format = SF.to_date("cola", "yy-MM-dd")
- self.assertEqual("TO_DATE(cola, 'yy-MM-dd')", col_with_format.sql())
-
- def test_to_timestamp(self):
- col_str = SF.to_timestamp("cola")
- self.assertEqual("CAST(cola AS TIMESTAMP)", col_str.sql())
- col = SF.to_timestamp(SF.col("cola"))
- self.assertEqual("CAST(cola AS TIMESTAMP)", col.sql())
- col_with_format = SF.to_timestamp("cola", "yyyy-MM-dd")
- self.assertEqual("TO_TIMESTAMP(cola, 'yyyy-MM-dd')", col_with_format.sql())
-
- def test_trunc(self):
- col_str = SF.trunc("cola", "year")
- self.assertEqual("TRUNC(cola, 'YEAR')", col_str.sql())
- col = SF.trunc(SF.col("cola"), "year")
- self.assertEqual("TRUNC(cola, 'YEAR')", col.sql())
-
- def test_date_trunc(self):
- col_str = SF.date_trunc("year", "cola")
- self.assertEqual("DATE_TRUNC('YEAR', cola)", col_str.sql())
- col = SF.date_trunc("YEAR", SF.col("cola"))
- self.assertEqual("DATE_TRUNC('YEAR', cola)", col.sql())
-
- def test_next_day(self):
- col_str = SF.next_day("cola", "Mon")
- self.assertEqual("NEXT_DAY(cola, 'Mon')", col_str.sql())
- col = SF.next_day(SF.col("cola"), "Mon")
- self.assertEqual("NEXT_DAY(cola, 'Mon')", col.sql())
-
- def test_last_day(self):
- col_str = SF.last_day("cola")
- self.assertEqual("LAST_DAY(cola)", col_str.sql())
- col = SF.last_day(SF.col("cola"))
- self.assertEqual("LAST_DAY(cola)", col.sql())
-
- def test_from_unixtime(self):
- col_str = SF.from_unixtime("cola")
- self.assertEqual("FROM_UNIXTIME(cola)", col_str.sql())
- col = SF.from_unixtime(SF.col("cola"))
- self.assertEqual("FROM_UNIXTIME(cola)", col.sql())
- col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm")
- self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm')", col_format.sql())
-
- def test_unix_timestamp(self):
- col_str = SF.unix_timestamp("cola")
- self.assertEqual("UNIX_TIMESTAMP(cola)", col_str.sql())
- col = SF.unix_timestamp(SF.col("cola"))
- self.assertEqual("UNIX_TIMESTAMP(cola)", col.sql())
- col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm")
- self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm')", col_format.sql())
- col_current = SF.unix_timestamp()
- self.assertEqual("UNIX_TIMESTAMP()", col_current.sql())
-
- def test_from_utc_timestamp(self):
- col_str = SF.from_utc_timestamp("cola", "PST")
- self.assertEqual("FROM_UTC_TIMESTAMP(cola, 'PST')", col_str.sql())
- col = SF.from_utc_timestamp(SF.col("cola"), "PST")
- self.assertEqual("FROM_UTC_TIMESTAMP(cola, 'PST')", col.sql())
- col_col = SF.from_utc_timestamp("cola", SF.col("colb"))
- self.assertEqual("FROM_UTC_TIMESTAMP(cola, colb)", col_col.sql())
-
- def test_to_utc_timestamp(self):
- col_str = SF.to_utc_timestamp("cola", "PST")
- self.assertEqual("TO_UTC_TIMESTAMP(cola, 'PST')", col_str.sql())
- col = SF.to_utc_timestamp(SF.col("cola"), "PST")
- self.assertEqual("TO_UTC_TIMESTAMP(cola, 'PST')", col.sql())
- col_col = SF.to_utc_timestamp("cola", SF.col("colb"))
- self.assertEqual("TO_UTC_TIMESTAMP(cola, colb)", col_col.sql())
-
- def test_timestamp_seconds(self):
- col_str = SF.timestamp_seconds("cola")
- self.assertEqual("TIMESTAMP_SECONDS(cola)", col_str.sql())
- col = SF.timestamp_seconds(SF.col("cola"))
- self.assertEqual("TIMESTAMP_SECONDS(cola)", col.sql())
-
- def test_window(self):
- col_str = SF.window("cola", "10 minutes")
- self.assertEqual("WINDOW(cola, '10 minutes')", col_str.sql())
- col = SF.window(SF.col("cola"), "10 minutes")
- self.assertEqual("WINDOW(cola, '10 minutes')", col.sql())
- col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds")
- self.assertEqual(
- "WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()
- )
- col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds")
- self.assertEqual(
- "WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()
- )
- col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds")
- self.assertEqual(
- "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')",
- col_no_slide.sql(),
- )
-
- def test_session_window(self):
- col_str = SF.session_window("cola", "5 seconds")
- self.assertEqual("SESSION_WINDOW(cola, '5 seconds')", col_str.sql())
- col = SF.session_window(SF.col("cola"), SF.lit("5 seconds"))
- self.assertEqual("SESSION_WINDOW(cola, '5 seconds')", col.sql())
-
- def test_crc32(self):
- col_str = SF.crc32("Spark")
- self.assertEqual("CRC32('Spark')", col_str.sql())
- col = SF.crc32(SF.col("cola"))
- self.assertEqual("CRC32(cola)", col.sql())
-
- def test_md5(self):
- col_str = SF.md5("Spark")
- self.assertEqual("MD5('Spark')", col_str.sql())
- col = SF.md5(SF.col("cola"))
- self.assertEqual("MD5(cola)", col.sql())
-
- def test_sha1(self):
- col_str = SF.sha1("Spark")
- self.assertEqual("SHA('Spark')", col_str.sql())
- col = SF.sha1(SF.col("cola"))
- self.assertEqual("SHA(cola)", col.sql())
-
- def test_sha2(self):
- col_str = SF.sha2("Spark", 256)
- self.assertEqual("SHA2('Spark', 256)", col_str.sql())
- col = SF.sha2(SF.col("cola"), 256)
- self.assertEqual("SHA2(cola, 256)", col.sql())
-
- def test_hash(self):
- col_str = SF.hash("cola", "colb", "colc")
- self.assertEqual("HASH(cola, colb, colc)", col_str.sql())
- col = SF.hash(SF.col("cola"), SF.col("colb"), SF.col("colc"))
- self.assertEqual("HASH(cola, colb, colc)", col.sql())
-
- def test_xxhash64(self):
- col_str = SF.xxhash64("cola", "colb", "colc")
- self.assertEqual("XXHASH64(cola, colb, colc)", col_str.sql())
- col = SF.xxhash64(SF.col("cola"), SF.col("colb"), SF.col("colc"))
- self.assertEqual("XXHASH64(cola, colb, colc)", col.sql())
-
- def test_assert_true(self):
- col = SF.assert_true(SF.col("cola") < SF.col("colb"))
- self.assertEqual("ASSERT_TRUE(cola < colb)", col.sql())
- col_error_msg_col = SF.assert_true(SF.col("cola") < SF.col("colb"), SF.col("colc"))
- self.assertEqual("ASSERT_TRUE(cola < colb, colc)", col_error_msg_col.sql())
- col_error_msg_lit = SF.assert_true(SF.col("cola") < SF.col("colb"), "error")
- self.assertEqual("ASSERT_TRUE(cola < colb, 'error')", col_error_msg_lit.sql())
-
- def test_raise_error(self):
- col_str = SF.raise_error("custom error message")
- self.assertEqual("RAISE_ERROR('custom error message')", col_str.sql())
- col = SF.raise_error(SF.col("cola"))
- self.assertEqual("RAISE_ERROR(cola)", col.sql())
-
- def test_upper(self):
- col_str = SF.upper("cola")
- self.assertEqual("UPPER(cola)", col_str.sql())
- col = SF.upper(SF.col("cola"))
- self.assertEqual("UPPER(cola)", col.sql())
-
- def test_lower(self):
- col_str = SF.lower("cola")
- self.assertEqual("LOWER(cola)", col_str.sql())
- col = SF.lower(SF.col("cola"))
- self.assertEqual("LOWER(cola)", col.sql())
-
- def test_ascii(self):
- col_str = SF.ascii(SF.lit(2))
- self.assertEqual("ASCII(2)", col_str.sql())
- col = SF.ascii(SF.col("cola"))
- self.assertEqual("ASCII(cola)", col.sql())
-
- def test_base64(self):
- col_str = SF.base64(SF.lit(2))
- self.assertEqual("BASE64(2)", col_str.sql())
- col = SF.base64(SF.col("cola"))
- self.assertEqual("BASE64(cola)", col.sql())
-
- def test_unbase64(self):
- col_str = SF.unbase64(SF.lit(2))
- self.assertEqual("UNBASE64(2)", col_str.sql())
- col = SF.unbase64(SF.col("cola"))
- self.assertEqual("UNBASE64(cola)", col.sql())
-
- def test_ltrim(self):
- col_str = SF.ltrim(SF.lit("Spark"))
- self.assertEqual("LTRIM('Spark')", col_str.sql())
- col = SF.ltrim(SF.col("cola"))
- self.assertEqual("LTRIM(cola)", col.sql())
-
- def test_rtrim(self):
- col_str = SF.rtrim(SF.lit("Spark"))
- self.assertEqual("RTRIM('Spark')", col_str.sql())
- col = SF.rtrim(SF.col("cola"))
- self.assertEqual("RTRIM(cola)", col.sql())
-
- def test_trim(self):
- col_str = SF.trim(SF.lit("Spark"))
- self.assertEqual("TRIM('Spark')", col_str.sql())
- col = SF.trim(SF.col("cola"))
- self.assertEqual("TRIM(cola)", col.sql())
-
- def test_concat_ws(self):
- col_str = SF.concat_ws("-", "cola", "colb")
- self.assertEqual("CONCAT_WS('-', cola, colb)", col_str.sql())
- col = SF.concat_ws("-", SF.col("cola"), SF.col("colb"))
- self.assertEqual("CONCAT_WS('-', cola, colb)", col.sql())
-
- def test_decode(self):
- col_str = SF.decode("cola", "US-ASCII")
- self.assertEqual("DECODE(cola, 'US-ASCII')", col_str.sql())
- col = SF.decode(SF.col("cola"), "US-ASCII")
- self.assertEqual("DECODE(cola, 'US-ASCII')", col.sql())
-
- def test_encode(self):
- col_str = SF.encode("cola", "US-ASCII")
- self.assertEqual("ENCODE(cola, 'US-ASCII')", col_str.sql())
- col = SF.encode(SF.col("cola"), "US-ASCII")
- self.assertEqual("ENCODE(cola, 'US-ASCII')", col.sql())
-
- def test_format_number(self):
- col_str = SF.format_number("cola", 4)
- self.assertEqual("FORMAT_NUMBER(cola, 4)", col_str.sql())
- col = SF.format_number(SF.col("cola"), 4)
- self.assertEqual("FORMAT_NUMBER(cola, 4)", col.sql())
-
- def test_format_string(self):
- col_str = SF.format_string("%d %s", "cola", "colb", "colc")
- self.assertEqual("FORMAT_STRING('%d %s', cola, colb, colc)", col_str.sql())
- col = SF.format_string("%d %s", SF.col("cola"), SF.col("colb"), SF.col("colc"))
- self.assertEqual("FORMAT_STRING('%d %s', cola, colb, colc)", col.sql())
-
- def test_instr(self):
- col_str = SF.instr("cola", "test")
- self.assertEqual("INSTR(cola, 'test')", col_str.sql())
- col = SF.instr(SF.col("cola"), "test")
- self.assertEqual("INSTR(cola, 'test')", col.sql())
-
- def test_overlay(self):
- col_str = SF.overlay("cola", "colb", 3, 7)
- self.assertEqual("OVERLAY(cola, colb, 3, 7)", col_str.sql())
- col = SF.overlay(SF.col("cola"), SF.col("colb"), SF.lit(3), SF.lit(7))
- self.assertEqual("OVERLAY(cola, colb, 3, 7)", col.sql())
- col_no_length = SF.overlay("cola", "colb", 3)
- self.assertEqual("OVERLAY(cola, colb, 3)", col_no_length.sql())
-
- def test_sentences(self):
- col_str = SF.sentences("cola", SF.lit("en"), SF.lit("US"))
- self.assertEqual("SENTENCES(cola, 'en', 'US')", col_str.sql())
- col = SF.sentences(SF.col("cola"), SF.lit("en"), SF.lit("US"))
- self.assertEqual("SENTENCES(cola, 'en', 'US')", col.sql())
- col_no_country = SF.sentences("cola", SF.lit("en"))
- self.assertEqual("SENTENCES(cola, 'en')", col_no_country.sql())
- col_no_lang = SF.sentences(SF.col("cola"), country=SF.lit("US"))
- self.assertEqual("SENTENCES(cola, 'en', 'US')", col_no_lang.sql())
- col_defaults = SF.sentences(SF.col("cola"))
- self.assertEqual("SENTENCES(cola)", col_defaults.sql())
-
- def test_substring(self):
- col_str = SF.substring("cola", 2, 3)
- self.assertEqual("SUBSTRING(cola, 2, 3)", col_str.sql())
- col = SF.substring(SF.col("cola"), 2, 3)
- self.assertEqual("SUBSTRING(cola, 2, 3)", col.sql())
-
- def test_substring_index(self):
- col_str = SF.substring_index("cola", ".", 2)
- self.assertEqual("SUBSTRING_INDEX(cola, '.', 2)", col_str.sql())
- col = SF.substring_index(SF.col("cola"), ".", 2)
- self.assertEqual("SUBSTRING_INDEX(cola, '.', 2)", col.sql())
-
- def test_levenshtein(self):
- col_str = SF.levenshtein("cola", "colb")
- self.assertEqual("LEVENSHTEIN(cola, colb)", col_str.sql())
- col = SF.levenshtein(SF.col("cola"), SF.col("colb"))
- self.assertEqual("LEVENSHTEIN(cola, colb)", col.sql())
-
- def test_locate(self):
- col_str = SF.locate("test", "cola", 3)
- self.assertEqual("LOCATE('test', cola, 3)", col_str.sql())
- col = SF.locate("test", SF.col("cola"), 3)
- self.assertEqual("LOCATE('test', cola, 3)", col.sql())
- col_no_pos = SF.locate("test", "cola")
- self.assertEqual("LOCATE('test', cola)", col_no_pos.sql())
-
- def test_lpad(self):
- col_str = SF.lpad("cola", 3, "#")
- self.assertEqual("LPAD(cola, 3, '#')", col_str.sql())
- col = SF.lpad(SF.col("cola"), 3, "#")
- self.assertEqual("LPAD(cola, 3, '#')", col.sql())
-
- def test_rpad(self):
- col_str = SF.rpad("cola", 3, "#")
- self.assertEqual("RPAD(cola, 3, '#')", col_str.sql())
- col = SF.rpad(SF.col("cola"), 3, "#")
- self.assertEqual("RPAD(cola, 3, '#')", col.sql())
-
- def test_repeat(self):
- col_str = SF.repeat("cola", 3)
- self.assertEqual("REPEAT(cola, 3)", col_str.sql())
- col = SF.repeat(SF.col("cola"), 3)
- self.assertEqual("REPEAT(cola, 3)", col.sql())
-
- def test_split(self):
- col_str = SF.split("cola", "[ABC]", 3)
- self.assertEqual("SPLIT(cola, '[ABC]', 3)", col_str.sql())
- col = SF.split(SF.col("cola"), "[ABC]", 3)
- self.assertEqual("SPLIT(cola, '[ABC]', 3)", col.sql())
- col_no_limit = SF.split("cola", "[ABC]")
- self.assertEqual("SPLIT(cola, '[ABC]')", col_no_limit.sql())
-
- def test_regexp_extract(self):
- col_str = SF.regexp_extract("cola", r"(\d+)-(\d+)", 1)
- self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col_str.sql())
- col = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)", 1)
- self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col.sql())
- col_no_idx = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)")
- self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)')", col_no_idx.sql())
-
- def test_regexp_replace(self):
- col_str = SF.regexp_replace("cola", r"(\d+)", "--")
- self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col_str.sql())
- col = SF.regexp_replace(SF.col("cola"), r"(\d+)", "--")
- self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col.sql())
-
- def test_initcap(self):
- col_str = SF.initcap("cola")
- self.assertEqual("INITCAP(cola)", col_str.sql())
- col = SF.initcap(SF.col("cola"))
- self.assertEqual("INITCAP(cola)", col.sql())
-
- def test_soundex(self):
- col_str = SF.soundex("cola")
- self.assertEqual("SOUNDEX(cola)", col_str.sql())
- col = SF.soundex(SF.col("cola"))
- self.assertEqual("SOUNDEX(cola)", col.sql())
-
- def test_bin(self):
- col_str = SF.bin("cola")
- self.assertEqual("BIN(cola)", col_str.sql())
- col = SF.bin(SF.col("cola"))
- self.assertEqual("BIN(cola)", col.sql())
-
- def test_hex(self):
- col_str = SF.hex("cola")
- self.assertEqual("HEX(cola)", col_str.sql())
- col = SF.hex(SF.col("cola"))
- self.assertEqual("HEX(cola)", col.sql())
-
- def test_unhex(self):
- col_str = SF.unhex("cola")
- self.assertEqual("UNHEX(cola)", col_str.sql())
- col = SF.unhex(SF.col("cola"))
- self.assertEqual("UNHEX(cola)", col.sql())
-
- def test_length(self):
- col_str = SF.length("cola")
- self.assertEqual("LENGTH(cola)", col_str.sql())
- col = SF.length(SF.col("cola"))
- self.assertEqual("LENGTH(cola)", col.sql())
-
- def test_octet_length(self):
- col_str = SF.octet_length("cola")
- self.assertEqual("OCTET_LENGTH(cola)", col_str.sql())
- col = SF.octet_length(SF.col("cola"))
- self.assertEqual("OCTET_LENGTH(cola)", col.sql())
-
- def test_bit_length(self):
- col_str = SF.bit_length("cola")
- self.assertEqual("BIT_LENGTH(cola)", col_str.sql())
- col = SF.bit_length(SF.col("cola"))
- self.assertEqual("BIT_LENGTH(cola)", col.sql())
-
- def test_translate(self):
- col_str = SF.translate("cola", "abc", "xyz")
- self.assertEqual("TRANSLATE(cola, 'abc', 'xyz')", col_str.sql())
- col = SF.translate(SF.col("cola"), "abc", "xyz")
- self.assertEqual("TRANSLATE(cola, 'abc', 'xyz')", col.sql())
-
- def test_array(self):
- col_str = SF.array("cola", "colb")
- self.assertEqual("ARRAY(cola, colb)", col_str.sql())
- col = SF.array(SF.col("cola"), SF.col("colb"))
- self.assertEqual("ARRAY(cola, colb)", col.sql())
- col_array = SF.array(["cola", "colb"])
- self.assertEqual("ARRAY(cola, colb)", col_array.sql())
-
- def test_create_map(self):
- col_str = SF.create_map("keya", "valuea", "keyb", "valueb")
- self.assertEqual("MAP(keya, valuea, keyb, valueb)", col_str.sql())
- col = SF.create_map(SF.col("keya"), SF.col("valuea"), SF.col("keyb"), SF.col("valueb"))
- self.assertEqual("MAP(keya, valuea, keyb, valueb)", col.sql())
- col_array = SF.create_map(["keya", "valuea", "keyb", "valueb"])
- self.assertEqual("MAP(keya, valuea, keyb, valueb)", col_array.sql())
-
- def test_map_from_arrays(self):
- col_str = SF.map_from_arrays("cola", "colb")
- self.assertEqual("MAP_FROM_ARRAYS(cola, colb)", col_str.sql())
- col = SF.map_from_arrays(SF.col("cola"), SF.col("colb"))
- self.assertEqual("MAP_FROM_ARRAYS(cola, colb)", col.sql())
-
- def test_array_contains(self):
- col_str = SF.array_contains("cola", "test")
- self.assertEqual("ARRAY_CONTAINS(cola, 'test')", col_str.sql())
- col = SF.array_contains(SF.col("cola"), "test")
- self.assertEqual("ARRAY_CONTAINS(cola, 'test')", col.sql())
- col_as_value = SF.array_contains("cola", SF.col("colb"))
- self.assertEqual("ARRAY_CONTAINS(cola, colb)", col_as_value.sql())
-
- def test_arrays_overlap(self):
- col_str = SF.arrays_overlap("cola", "colb")
- self.assertEqual("ARRAYS_OVERLAP(cola, colb)", col_str.sql())
- col = SF.arrays_overlap(SF.col("cola"), SF.col("colb"))
- self.assertEqual("ARRAYS_OVERLAP(cola, colb)", col.sql())
-
- def test_slice(self):
- col_str = SF.slice("cola", SF.col("colb"), SF.col("colc"))
- self.assertEqual("SLICE(cola, colb, colc)", col_str.sql())
- col = SF.slice(SF.col("cola"), SF.col("colb"), SF.col("colc"))
- self.assertEqual("SLICE(cola, colb, colc)", col.sql())
- col_ints = SF.slice("cola", 1, 10)
- self.assertEqual("SLICE(cola, 1, 10)", col_ints.sql())
-
- def test_array_join(self):
- col_str = SF.array_join("cola", "-", "NULL_REPLACEMENT")
- self.assertEqual("ARRAY_JOIN(cola, '-', 'NULL_REPLACEMENT')", col_str.sql())
- col = SF.array_join(SF.col("cola"), "-", "NULL_REPLACEMENT")
- self.assertEqual("ARRAY_JOIN(cola, '-', 'NULL_REPLACEMENT')", col.sql())
- col_no_replacement = SF.array_join("cola", "-")
- self.assertEqual("ARRAY_JOIN(cola, '-')", col_no_replacement.sql())
-
- def test_concat(self):
- col_str = SF.concat("cola", "colb")
- self.assertEqual("CONCAT(cola, colb)", col_str.sql())
- col = SF.concat(SF.col("cola"), SF.col("colb"))
- self.assertEqual("CONCAT(cola, colb)", col.sql())
- col_single = SF.concat("cola")
- self.assertEqual("CONCAT(cola)", col_single.sql())
-
- def test_array_position(self):
- col_str = SF.array_position("cola", SF.col("colb"))
- self.assertEqual("ARRAY_POSITION(cola, colb)", col_str.sql())
- col = SF.array_position(SF.col("cola"), SF.col("colb"))
- self.assertEqual("ARRAY_POSITION(cola, colb)", col.sql())
- col_lit = SF.array_position("cola", "test")
- self.assertEqual("ARRAY_POSITION(cola, 'test')", col_lit)
-
- def test_element_at(self):
- col_str = SF.element_at("cola", SF.col("colb"))
- self.assertEqual("ELEMENT_AT(cola, colb)", col_str.sql())
- col = SF.element_at(SF.col("cola"), SF.col("colb"))
- self.assertEqual("ELEMENT_AT(cola, colb)", col.sql())
- col_lit = SF.element_at("cola", "test")
- self.assertEqual("ELEMENT_AT(cola, 'test')", col_lit)
-
- def test_array_remove(self):
- col_str = SF.array_remove("cola", SF.col("colb"))
- self.assertEqual("ARRAY_REMOVE(cola, colb)", col_str.sql())
- col = SF.array_remove(SF.col("cola"), SF.col("colb"))
- self.assertEqual("ARRAY_REMOVE(cola, colb)", col.sql())
- col_lit = SF.array_remove("cola", "test")
- self.assertEqual("ARRAY_REMOVE(cola, 'test')", col_lit)
-
- def test_array_distinct(self):
- col_str = SF.array_distinct("cola")
- self.assertEqual("ARRAY_DISTINCT(cola)", col_str.sql())
- col = SF.array_distinct(SF.col("cola"))
- self.assertEqual("ARRAY_DISTINCT(cola)", col.sql())
-
- def test_array_intersect(self):
- col_str = SF.array_intersect("cola", "colb")
- self.assertEqual("ARRAY_INTERSECT(cola, colb)", col_str.sql())
- col = SF.array_intersect(SF.col("cola"), SF.col("colb"))
- self.assertEqual("ARRAY_INTERSECT(cola, colb)", col.sql())
-
- def test_array_union(self):
- col_str = SF.array_union("cola", "colb")
- self.assertEqual("ARRAY_UNION(cola, colb)", col_str.sql())
- col = SF.array_union(SF.col("cola"), SF.col("colb"))
- self.assertEqual("ARRAY_UNION(cola, colb)", col.sql())
-
- def test_array_except(self):
- col_str = SF.array_except("cola", "colb")
- self.assertEqual("ARRAY_EXCEPT(cola, colb)", col_str.sql())
- col = SF.array_except(SF.col("cola"), SF.col("colb"))
- self.assertEqual("ARRAY_EXCEPT(cola, colb)", col.sql())
-
- def test_explode(self):
- col_str = SF.explode("cola")
- self.assertEqual("EXPLODE(cola)", col_str.sql())
- col = SF.explode(SF.col("cola"))
- self.assertEqual("EXPLODE(cola)", col.sql())
-
- def test_pos_explode(self):
- col_str = SF.posexplode("cola")
- self.assertEqual("POSEXPLODE(cola)", col_str.sql())
- col = SF.posexplode(SF.col("cola"))
- self.assertEqual("POSEXPLODE(cola)", col.sql())
-
- def test_explode_outer(self):
- col_str = SF.explode_outer("cola")
- self.assertEqual("EXPLODE_OUTER(cola)", col_str.sql())
- col = SF.explode_outer(SF.col("cola"))
- self.assertEqual("EXPLODE_OUTER(cola)", col.sql())
-
- def test_posexplode_outer(self):
- col_str = SF.posexplode_outer("cola")
- self.assertEqual("POSEXPLODE_OUTER(cola)", col_str.sql())
- col = SF.posexplode_outer(SF.col("cola"))
- self.assertEqual("POSEXPLODE_OUTER(cola)", col.sql())
-
- def test_get_json_object(self):
- col_str = SF.get_json_object("cola", "$.f1")
- self.assertEqual("GET_JSON_OBJECT(cola, '$.f1')", col_str.sql())
- col = SF.get_json_object(SF.col("cola"), "$.f1")
- self.assertEqual("GET_JSON_OBJECT(cola, '$.f1')", col.sql())
-
- def test_json_tuple(self):
- col_str = SF.json_tuple("cola", "f1", "f2")
- self.assertEqual("JSON_TUPLE(cola, 'f1', 'f2')", col_str.sql())
- col = SF.json_tuple(SF.col("cola"), "f1", "f2")
- self.assertEqual("JSON_TUPLE(cola, 'f1', 'f2')", col.sql())
-
- def test_from_json(self):
- col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual(
- "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
- )
- col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual(
- "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
- )
- col_no_option = SF.from_json("cola", "cola INT")
- self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql())
-
- def test_to_json(self):
- col_str = SF.to_json("cola", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("TO_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
- col = SF.to_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("TO_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
- col_no_option = SF.to_json("cola")
- self.assertEqual("TO_JSON(cola)", col_no_option.sql())
-
- def test_schema_of_json(self):
- col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual(
- "SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
- )
- col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
- col_no_option = SF.schema_of_json("cola")
- self.assertEqual("SCHEMA_OF_JSON(cola)", col_no_option.sql())
-
- def test_schema_of_csv(self):
- col_str = SF.schema_of_csv("cola", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("SCHEMA_OF_CSV(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
- col = SF.schema_of_csv(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("SCHEMA_OF_CSV(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
- col_no_option = SF.schema_of_csv("cola")
- self.assertEqual("SCHEMA_OF_CSV(cola)", col_no_option.sql())
-
- def test_to_csv(self):
- col_str = SF.to_csv("cola", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("TO_CSV(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
- col = SF.to_csv(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("TO_CSV(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
- col_no_option = SF.to_csv("cola")
- self.assertEqual("TO_CSV(cola)", col_no_option.sql())
-
- def test_size(self):
- col_str = SF.size("cola")
- self.assertEqual("SIZE(cola)", col_str.sql())
- col = SF.size(SF.col("cola"))
- self.assertEqual("SIZE(cola)", col.sql())
-
- def test_array_min(self):
- col_str = SF.array_min("cola")
- self.assertEqual("ARRAY_MIN(cola)", col_str.sql())
- col = SF.array_min(SF.col("cola"))
- self.assertEqual("ARRAY_MIN(cola)", col.sql())
-
- def test_array_max(self):
- col_str = SF.array_max("cola")
- self.assertEqual("ARRAY_MAX(cola)", col_str.sql())
- col = SF.array_max(SF.col("cola"))
- self.assertEqual("ARRAY_MAX(cola)", col.sql())
-
- def test_sort_array(self):
- col_str = SF.sort_array("cola", False)
- self.assertEqual("SORT_ARRAY(cola, FALSE)", col_str.sql())
- col = SF.sort_array(SF.col("cola"), False)
- self.assertEqual("SORT_ARRAY(cola, FALSE)", col.sql())
- col_no_sort = SF.sort_array("cola")
- self.assertEqual("SORT_ARRAY(cola)", col_no_sort.sql())
-
- def test_array_sort(self):
- col_str = SF.array_sort("cola")
- self.assertEqual("ARRAY_SORT(cola)", col_str.sql())
- col = SF.array_sort(SF.col("cola"))
- self.assertEqual("ARRAY_SORT(cola)", col.sql())
- col_comparator = SF.array_sort(
- "cola",
- lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(
- SF.length(y) - SF.length(x)
- ),
- )
- self.assertEqual(
- "ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)",
- col_comparator.sql(),
- )
-
- def test_reverse(self):
- col_str = SF.reverse("cola")
- self.assertEqual("REVERSE(cola)", col_str.sql())
- col = SF.reverse(SF.col("cola"))
- self.assertEqual("REVERSE(cola)", col.sql())
-
- def test_flatten(self):
- col_str = SF.flatten("cola")
- self.assertEqual("FLATTEN(cola)", col_str.sql())
- col = SF.flatten(SF.col("cola"))
- self.assertEqual("FLATTEN(cola)", col.sql())
-
- def test_map_keys(self):
- col_str = SF.map_keys("cola")
- self.assertEqual("MAP_KEYS(cola)", col_str.sql())
- col = SF.map_keys(SF.col("cola"))
- self.assertEqual("MAP_KEYS(cola)", col.sql())
-
- def test_map_values(self):
- col_str = SF.map_values("cola")
- self.assertEqual("MAP_VALUES(cola)", col_str.sql())
- col = SF.map_values(SF.col("cola"))
- self.assertEqual("MAP_VALUES(cola)", col.sql())
-
- def test_map_entries(self):
- col_str = SF.map_entries("cola")
- self.assertEqual("MAP_ENTRIES(cola)", col_str.sql())
- col = SF.map_entries(SF.col("cola"))
- self.assertEqual("MAP_ENTRIES(cola)", col.sql())
-
- def test_map_from_entries(self):
- col_str = SF.map_from_entries("cola")
- self.assertEqual("MAP_FROM_ENTRIES(cola)", col_str.sql())
- col = SF.map_from_entries(SF.col("cola"))
- self.assertEqual("MAP_FROM_ENTRIES(cola)", col.sql())
-
- def test_array_repeat(self):
- col_str = SF.array_repeat("cola", 2)
- self.assertEqual("ARRAY_REPEAT(cola, 2)", col_str.sql())
- col = SF.array_repeat(SF.col("cola"), 2)
- self.assertEqual("ARRAY_REPEAT(cola, 2)", col.sql())
-
- def test_array_zip(self):
- col_str = SF.array_zip("cola", "colb")
- self.assertEqual("ARRAY_ZIP(cola, colb)", col_str.sql())
- col = SF.array_zip(SF.col("cola"), SF.col("colb"))
- self.assertEqual("ARRAY_ZIP(cola, colb)", col.sql())
- col_single = SF.array_zip("cola")
- self.assertEqual("ARRAY_ZIP(cola)", col_single.sql())
-
- def test_map_concat(self):
- col_str = SF.map_concat("cola", "colb")
- self.assertEqual("MAP_CONCAT(cola, colb)", col_str.sql())
- col = SF.map_concat(SF.col("cola"), SF.col("colb"))
- self.assertEqual("MAP_CONCAT(cola, colb)", col.sql())
- col_single = SF.map_concat("cola")
- self.assertEqual("MAP_CONCAT(cola)", col_single.sql())
-
- def test_sequence(self):
- col_str = SF.sequence("cola", "colb", "colc")
- self.assertEqual("SEQUENCE(cola, colb, colc)", col_str.sql())
- col = SF.sequence(SF.col("cola"), SF.col("colb"), SF.col("colc"))
- self.assertEqual("SEQUENCE(cola, colb, colc)", col.sql())
- col_no_step = SF.sequence("cola", "colb")
- self.assertEqual("SEQUENCE(cola, colb)", col_no_step.sql())
-
- def test_from_csv(self):
- col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual(
- "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
- )
- col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual(
- "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
- )
- col_no_option = SF.from_csv("cola", "cola INT")
- self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql())
-
- def test_aggregate(self):
- col_str = SF.aggregate("cola", SF.lit(0), lambda acc, x: acc + x, lambda acc: acc * 2)
- self.assertEqual("AGGREGATE(cola, 0, (acc, x) -> acc + x, acc -> acc * 2)", col_str.sql())
- col = SF.aggregate(SF.col("cola"), SF.lit(0), lambda acc, x: acc + x, lambda acc: acc * 2)
- self.assertEqual("AGGREGATE(cola, 0, (acc, x) -> acc + x, acc -> acc * 2)", col.sql())
- col_no_finish = SF.aggregate("cola", SF.lit(0), lambda acc, x: acc + x)
- self.assertEqual("AGGREGATE(cola, 0, (acc, x) -> acc + x)", col_no_finish.sql())
- col_custom_names = SF.aggregate(
- "cola",
- SF.lit(0),
- lambda accumulator, target: accumulator + target,
- lambda accumulator: accumulator * 2,
- )
- self.assertEqual(
- "AGGREGATE(cola, 0, (accumulator, target) -> accumulator + target, accumulator -> accumulator * 2)",
- col_custom_names.sql(),
- )
-
- def test_transform(self):
- col_str = SF.transform("cola", lambda x: x * 2)
- self.assertEqual("TRANSFORM(cola, x -> x * 2)", col_str.sql())
- col = SF.transform(SF.col("cola"), lambda x, i: x * i)
- self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql())
- col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count)
-
- self.assertEqual(
- "TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()
- )
-
- def test_exists(self):
- col_str = SF.exists("cola", lambda x: x % 2 == 0)
- self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col_str.sql())
- col = SF.exists(SF.col("cola"), lambda x: x % 2 == 0)
- self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col.sql())
- col_custom_name = SF.exists("cola", lambda target: target > 0)
- self.assertEqual("EXISTS(cola, target -> target > 0)", col_custom_name.sql())
-
- def test_forall(self):
- col_str = SF.forall("cola", lambda x: x.rlike("foo"))
- self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col_str.sql())
- col = SF.forall(SF.col("cola"), lambda x: x.rlike("foo"))
- self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col.sql())
- col_custom_name = SF.forall("cola", lambda target: target.rlike("foo"))
- self.assertEqual("FORALL(cola, target -> target RLIKE 'foo')", col_custom_name.sql())
-
- def test_filter(self):
- col_str = SF.filter("cola", lambda x: SF.month(SF.to_date(x)) > SF.lit(6))
- self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql())
- col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i))
- self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql())
- col_custom_names = SF.filter(
- "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)
- )
-
- self.assertEqual(
- "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)",
- col_custom_names.sql(),
- )
-
- def test_zip_with(self):
- col_str = SF.zip_with("cola", "colb", lambda x, y: SF.concat_ws("_", x, y))
- self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col_str.sql())
- col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y))
- self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql())
- col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r))
- self.assertEqual(
- "ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()
- )
-
- def test_transform_keys(self):
- col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k))
- self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col_str.sql())
- col = SF.transform_keys(SF.col("cola"), lambda k, v: SF.upper(k))
- self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col.sql())
- col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key))
- self.assertEqual("TRANSFORM_KEYS(cola, (key, _) -> UPPER(key))", col_custom_names.sql())
-
- def test_transform_values(self):
- col_str = SF.transform_values("cola", lambda k, v: SF.upper(v))
- self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col_str.sql())
- col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v))
- self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql())
- col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value))
- self.assertEqual(
- "TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()
- )
-
- def test_map_filter(self):
- col_str = SF.map_filter("cola", lambda k, v: k > v)
- self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col_str.sql())
- col = SF.map_filter(SF.col("cola"), lambda k, v: k > v)
- self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col.sql())
- col_custom_names = SF.map_filter("cola", lambda key, value: key > value)
- self.assertEqual("MAP_FILTER(cola, (key, value) -> key > value)", col_custom_names.sql())
-
- def test_map_zip_with(self):
- col = SF.map_zip_with("base", "ratio", lambda k, v1, v2: SF.round(v1 * v2, 2))
- self.assertEqual("MAP_ZIP_WITH(base, ratio, (k, v1, v2) -> ROUND(v1 * v2, 2))", col.sql())
diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py
deleted file mode 100644
index 848c603..0000000
--- a/tests/dataframe/unit/test_session.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import sqlglot
-from sqlglot.dataframe.sql import functions as F, types
-from sqlglot.dataframe.sql.session import SparkSession
-from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
-
-
-class TestDataframeSession(DataFrameSQLValidator):
- def test_cdf_one_row(self):
- df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"])
- expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2) AS `a2`(`cola`, `colb`)"
- self.compare_sql(df, expected)
-
- def test_cdf_multiple_rows(self):
- df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"])
- expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`cola`, `colb`)"
- self.compare_sql(df, expected)
-
- def test_cdf_no_schema(self):
- df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
- expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`_1`, `_2`)"
- self.compare_sql(df, expected)
-
- def test_cdf_row_mixed_primitives(self):
- df = self.spark.createDataFrame([[1, 10.1, "test", False, None]])
- expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM VALUES (1, 10.1, 'test', FALSE, NULL) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
- self.compare_sql(df, expected)
-
- def test_cdf_dict_rows(self):
- df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}])
- expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 'test'), (2, 'test2') AS `a2`(`cola`, `colb`)"
- self.compare_sql(df, expected)
-
- def test_cdf_str_schema(self):
- df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
- expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
- self.compare_sql(df, expected)
-
- def test_typed_schema_basic(self):
- schema = types.StructType(
- [
- types.StructField("cola", types.IntegerType()),
- types.StructField("colb", types.StringType()),
- ]
- )
- df = self.spark.createDataFrame([[1, "test"]], schema)
- expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
- self.compare_sql(df, expected)
-
- def test_typed_schema_nested(self):
- schema = types.StructType(
- [
- types.StructField(
- "cola",
- types.StructType(
- [
- types.StructField("sub_cola", types.IntegerType()),
- types.StructField("sub_colb", types.StringType()),
- ]
- ),
- )
- ]
- )
- df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema)
- expected = "SELECT `a2`.`cola` AS `cola` FROM VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`)) AS `a2`(`cola`)"
-
- self.compare_sql(df, expected)
-
- def test_sql_select_only(self):
- query = "SELECT cola, colb FROM table"
- sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
- df = self.spark.sql(query)
- self.assertEqual(
- "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`",
- df.sql(pretty=False)[0],
- )
-
- def test_sql_with_aggs(self):
- query = "SELECT cola, colb FROM table"
- sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
- df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb"))
- self.assertEqual(
- "WITH t26614 AS (SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`), t23454 AS (SELECT cola, colb FROM t26614) SELECT cola, SUM(colb) FROM t23454 GROUP BY cola",
- df.sql(pretty=False, optimize=False)[0],
- )
-
- def test_sql_create(self):
- query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1"
- sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
- df = self.spark.sql(query)
- expected = "CREATE TABLE `new_table` AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
- self.compare_sql(df, expected)
-
- def test_sql_insert(self):
- query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
- sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
- df = self.spark.sql(query)
- expected = "INSERT INTO `new_table` SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
- self.compare_sql(df, expected)
-
- def test_session_create_builder_patterns(self):
- self.assertEqual(SparkSession.builder.appName("abc").getOrCreate(), SparkSession())
diff --git a/tests/dataframe/unit/test_session_case_sensitivity.py b/tests/dataframe/unit/test_session_case_sensitivity.py
deleted file mode 100644
index 462edb6..0000000
--- a/tests/dataframe/unit/test_session_case_sensitivity.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import sqlglot
-from sqlglot.dataframe.sql import functions as F
-from sqlglot.dataframe.sql.session import SparkSession
-from sqlglot.errors import OptimizeError
-from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase
-
-
-class TestSessionCaseSensitivity(DataFrameTestBase):
- def setUp(self) -> None:
- super().setUp()
- self.spark = SparkSession.builder.config("sqlframe.dialect", "snowflake").getOrCreate()
-
- tests = [
- (
- "All lower no intention of CS",
- "test",
- "test",
- {"name": "VARCHAR"},
- "name",
- '''SELECT "TEST"."NAME" AS "NAME" FROM "TEST" AS "TEST"''',
- ),
- (
- "Table has CS while column does not",
- '"Test"',
- '"Test"',
- {"name": "VARCHAR"},
- "name",
- '''SELECT "Test"."NAME" AS "NAME" FROM "Test" AS "Test"''',
- ),
- (
- "Column has CS while table does not",
- "test",
- "test",
- {'"Name"': "VARCHAR"},
- '"Name"',
- '''SELECT "TEST"."Name" AS "Name" FROM "TEST" AS "TEST"''',
- ),
- (
- "Both Table and column have CS",
- '"Test"',
- '"Test"',
- {'"Name"': "VARCHAR"},
- '"Name"',
- '''SELECT "Test"."Name" AS "Name" FROM "Test" AS "Test"''',
- ),
- (
- "Lowercase CS table and column",
- '"test"',
- '"test"',
- {'"name"': "VARCHAR"},
- '"name"',
- '''SELECT "test"."name" AS "name" FROM "test" AS "test"''',
- ),
- (
- "CS table and column and query table but no CS in query column",
- '"test"',
- '"test"',
- {'"name"': "VARCHAR"},
- "name",
- OptimizeError(),
- ),
- (
- "CS table and column and query column but no CS in query table",
- '"test"',
- "test",
- {'"name"': "VARCHAR"},
- '"name"',
- OptimizeError(),
- ),
- ]
-
- def test_basic_case_sensitivity(self):
- for test_name, table_name, spark_table, schema, spark_column, expected in self.tests:
- with self.subTest(test_name):
- sqlglot.schema.add_table(table_name, schema, dialect=self.spark.dialect)
- df = self.spark.table(spark_table).select(F.col(spark_column))
- if isinstance(expected, OptimizeError):
- with self.assertRaises(OptimizeError):
- df.sql()
- else:
- self.compare_sql(df, expected)
-
- def test_alias(self):
- col = F.col('"Name"')
- self.assertEqual(col.sql(dialect=self.spark.dialect), '"Name"')
- self.assertEqual(col.alias("nAME").sql(dialect=self.spark.dialect), '"Name" AS NAME')
- self.assertEqual(col.alias('"nAME"').sql(dialect=self.spark.dialect), '"Name" AS "nAME"')
diff --git a/tests/dataframe/unit/test_types.py b/tests/dataframe/unit/test_types.py
deleted file mode 100644
index 52f5d72..0000000
--- a/tests/dataframe/unit/test_types.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import unittest
-
-from sqlglot.dataframe.sql import types
-
-
-class TestDataframeTypes(unittest.TestCase):
- def test_string(self):
- self.assertEqual("string", types.StringType().simpleString())
-
- def test_char(self):
- self.assertEqual("char(100)", types.CharType(100).simpleString())
-
- def test_varchar(self):
- self.assertEqual("varchar(65)", types.VarcharType(65).simpleString())
-
- def test_binary(self):
- self.assertEqual("binary", types.BinaryType().simpleString())
-
- def test_boolean(self):
- self.assertEqual("boolean", types.BooleanType().simpleString())
-
- def test_date(self):
- self.assertEqual("date", types.DateType().simpleString())
-
- def test_timestamp(self):
- self.assertEqual("timestamp", types.TimestampType().simpleString())
-
- def test_timestamp_ntz(self):
- self.assertEqual("timestamp_ntz", types.TimestampNTZType().simpleString())
-
- def test_decimal(self):
- self.assertEqual("decimal(10, 3)", types.DecimalType(10, 3).simpleString())
-
- def test_double(self):
- self.assertEqual("double", types.DoubleType().simpleString())
-
- def test_float(self):
- self.assertEqual("float", types.FloatType().simpleString())
-
- def test_byte(self):
- self.assertEqual("tinyint", types.ByteType().simpleString())
-
- def test_integer(self):
- self.assertEqual("int", types.IntegerType().simpleString())
-
- def test_long(self):
- self.assertEqual("bigint", types.LongType().simpleString())
-
- def test_short(self):
- self.assertEqual("smallint", types.ShortType().simpleString())
-
- def test_array(self):
- self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())
-
- def test_map(self):
- self.assertEqual(
- "map<int, string>",
- types.MapType(types.IntegerType(), types.StringType()).simpleString(),
- )
-
- def test_struct_field(self):
- self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())
-
- def test_struct_type(self):
- self.assertEqual(
- "struct<cola:int, colb:string>",
- types.StructType(
- [
- types.StructField("cola", types.IntegerType()),
- types.StructField("colb", types.StringType()),
- ]
- ).simpleString(),
- )
diff --git a/tests/dataframe/unit/test_window.py b/tests/dataframe/unit/test_window.py
deleted file mode 100644
index 9c4c897..0000000
--- a/tests/dataframe/unit/test_window.py
+++ /dev/null
@@ -1,75 +0,0 @@
-from sqlglot.dataframe.sql import functions as F
-from sqlglot.dataframe.sql.window import Window, WindowSpec
-from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase
-
-
-class TestDataframeWindow(DataFrameTestBase):
- def test_window_spec_partition_by(self):
- partition_by = WindowSpec().partitionBy(F.col("cola"), F.col("colb"))
- self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql())
-
- def test_window_spec_order_by(self):
- order_by = WindowSpec().orderBy("cola", "colb")
- self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql())
-
- def test_window_spec_rows_between(self):
- rows_between = WindowSpec().rowsBetween(3, 5)
- self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
-
- def test_window_spec_range_between(self):
- range_between = WindowSpec().rangeBetween(3, 5)
- self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
-
- def test_window_partition_by(self):
- partition_by = Window.partitionBy(F.col("cola"), F.col("colb"))
- self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql())
-
- def test_window_order_by(self):
- order_by = Window.orderBy("cola", "colb")
- self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql())
-
- def test_window_rows_between(self):
- rows_between = Window.rowsBetween(3, 5)
- self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
-
- def test_window_range_between(self):
- range_between = Window.rangeBetween(3, 5)
- self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
-
- def test_window_rows_unbounded(self):
- rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
- self.assertEqual(
- "OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
- rows_between_unbounded_start.sql(),
- )
- rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
- self.assertEqual(
- "OVER (ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
- rows_between_unbounded_end.sql(),
- )
- rows_between_unbounded_both = Window.rowsBetween(
- Window.unboundedPreceding, Window.unboundedFollowing
- )
- self.assertEqual(
- "OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
- rows_between_unbounded_both.sql(),
- )
-
- def test_window_range_unbounded(self):
- range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
- self.assertEqual(
- "OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
- range_between_unbounded_start.sql(),
- )
- range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
- self.assertEqual(
- "OVER (RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
- range_between_unbounded_end.sql(),
- )
- range_between_unbounded_both = Window.rangeBetween(
- Window.unboundedPreceding, Window.unboundedFollowing
- )
- self.assertEqual(
- "OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
- range_between_unbounded_both.sql(),
- )
diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py
index 14a6bf3..f050cfa 100644
--- a/tests/dialects/test_databricks.py
+++ b/tests/dialects/test_databricks.py
@@ -20,7 +20,6 @@ class TestDatabricks(Validator):
self.validate_identity("SELECT CAST('23:00:00' AS INTERVAL MINUTE TO SECOND)")
self.validate_identity("CREATE TABLE target SHALLOW CLONE source")
self.validate_identity("INSERT INTO a REPLACE WHERE cond VALUES (1), (2)")
- self.validate_identity("SELECT c1 : price")
self.validate_identity("CREATE FUNCTION a.b(x INT) RETURNS INT RETURN x + 1")
self.validate_identity("CREATE FUNCTION a AS b")
self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1")
@@ -68,6 +67,20 @@ class TestDatabricks(Validator):
},
)
+ self.validate_all(
+ "SELECT X'1A2B'",
+ read={
+ "spark2": "SELECT X'1A2B'",
+ "spark": "SELECT X'1A2B'",
+ "databricks": "SELECT x'1A2B'",
+ },
+ write={
+ "spark2": "SELECT X'1A2B'",
+ "spark": "SELECT X'1A2B'",
+ "databricks": "SELECT X'1A2B'",
+ },
+ )
+
with self.assertRaises(ParseError):
transpile(
"CREATE FUNCTION add_one(x INT) RETURNS INT LANGUAGE PYTHON AS $foo$def add_one(x):\n return x+1$$",
@@ -82,37 +95,33 @@ class TestDatabricks(Validator):
# https://docs.databricks.com/sql/language-manual/functions/colonsign.html
def test_json(self):
- self.validate_identity("""SELECT c1 : price FROM VALUES ('{ "price": 5 }') AS T(c1)""")
-
- self.validate_all(
+ self.validate_identity(
+ """SELECT c1 : price FROM VALUES ('{ "price": 5 }') AS T(c1)""",
+ """SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""",
+ )
+ self.validate_identity(
"""SELECT c1:['price'] FROM VALUES('{ "price": 5 }') AS T(c1)""",
- write={
- "databricks": """SELECT c1 : ARRAY('price') FROM VALUES ('{ "price": 5 }') AS T(c1)""",
- },
+ """SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""",
)
- self.validate_all(
+ self.validate_identity(
"""SELECT c1:item[1].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- write={
- "databricks": """SELECT c1 : item[1].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- },
+ """SELECT GET_JSON_OBJECT(c1, '$.item[1].price') FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
)
- self.validate_all(
+ self.validate_identity(
"""SELECT c1:item[*].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- write={
- "databricks": """SELECT c1 : item[*].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- },
+ """SELECT GET_JSON_OBJECT(c1, '$.item[*].price') FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
)
- self.validate_all(
+ self.validate_identity(
"""SELECT from_json(c1:item[*].price, 'ARRAY<DOUBLE>')[0] FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- write={
- "databricks": """SELECT FROM_JSON(c1 : item[*].price, 'ARRAY<DOUBLE>')[0] FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- },
+ """SELECT FROM_JSON(GET_JSON_OBJECT(c1, '$.item[*].price'), 'ARRAY<DOUBLE>')[0] FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
)
- self.validate_all(
+ self.validate_identity(
"""SELECT inline(from_json(c1:item[*], 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- write={
- "databricks": """SELECT INLINE(FROM_JSON(c1 : item[*], 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- },
+ """SELECT INLINE(FROM_JSON(GET_JSON_OBJECT(c1, '$.item[*]'), 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
+ )
+ self.validate_identity(
+ "SELECT c1 : price",
+ "SELECT GET_JSON_OBJECT(c1, '$.price')",
)
def test_datediff(self):
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 77306dc..9888a5d 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -1163,7 +1163,7 @@ class TestDialect(Validator):
read={
"bigquery": "JSON_EXTRACT(x, '$.y')",
"duckdb": "x -> 'y'",
- "doris": "x -> '$.y'",
+ "doris": "JSON_EXTRACT(x, '$.y')",
"mysql": "JSON_EXTRACT(x, '$.y')",
"postgres": "x->'y'",
"presto": "JSON_EXTRACT(x, '$.y')",
@@ -1174,7 +1174,7 @@ class TestDialect(Validator):
write={
"bigquery": "JSON_EXTRACT(x, '$.y')",
"clickhouse": "JSONExtractString(x, 'y')",
- "doris": "x -> '$.y'",
+ "doris": "JSON_EXTRACT(x, '$.y')",
"duckdb": "x -> '$.y'",
"mysql": "JSON_EXTRACT(x, '$.y')",
"oracle": "JSON_EXTRACT(x, '$.y')",
@@ -1218,7 +1218,7 @@ class TestDialect(Validator):
read={
"bigquery": "JSON_EXTRACT(x, '$.y[0].z')",
"duckdb": "x -> '$.y[0].z'",
- "doris": "x -> '$.y[0].z'",
+ "doris": "JSON_EXTRACT(x, '$.y[0].z')",
"mysql": "JSON_EXTRACT(x, '$.y[0].z')",
"presto": "JSON_EXTRACT(x, '$.y[0].z')",
"snowflake": "GET_PATH(x, 'y[0].z')",
@@ -1228,7 +1228,7 @@ class TestDialect(Validator):
write={
"bigquery": "JSON_EXTRACT(x, '$.y[0].z')",
"clickhouse": "JSONExtractString(x, 'y', 1, 'z')",
- "doris": "x -> '$.y[0].z'",
+ "doris": "JSON_EXTRACT(x, '$.y[0].z')",
"duckdb": "x -> '$.y[0].z'",
"mysql": "JSON_EXTRACT(x, '$.y[0].z')",
"oracle": "JSON_EXTRACT(x, '$.y[0].z')",
diff --git a/tests/dialects/test_doris.py b/tests/dialects/test_doris.py
index 035289b..f7fce02 100644
--- a/tests/dialects/test_doris.py
+++ b/tests/dialects/test_doris.py
@@ -14,7 +14,9 @@ class TestDoris(Validator):
)
self.validate_all(
"SELECT MAX_BY(a, b), MIN_BY(c, d)",
- read={"clickhouse": "SELECT argMax(a, b), argMin(c, d)"},
+ read={
+ "clickhouse": "SELECT argMax(a, b), argMin(c, d)",
+ },
)
self.validate_all(
"SELECT ARRAY_SUM(x -> x * x, ARRAY(2, 3))",
@@ -36,6 +38,16 @@ class TestDoris(Validator):
"oracle": "ADD_MONTHS(d, n)",
},
)
+ self.validate_all(
+ """SELECT JSON_EXTRACT(CAST('{"key": 1}' AS JSONB), '$.key')""",
+ read={
+ "postgres": """SELECT '{"key": 1}'::jsonb ->> 'key'""",
+ },
+ write={
+ "doris": """SELECT JSON_EXTRACT(CAST('{"key": 1}' AS JSONB), '$.key')""",
+ "postgres": """SELECT JSON_EXTRACT_PATH(CAST('{"key": 1}' AS JSONB), 'key')""",
+ },
+ )
def test_identity(self):
self.validate_identity("COALECSE(a, b, c, d)")
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index 84fb3c2..591b5dd 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -155,6 +155,10 @@ class TestMySQL(Validator):
"""SELECT * FROM foo WHERE 3 MEMBER OF(info->'$.value')""",
"""SELECT * FROM foo WHERE 3 MEMBER OF(JSON_EXTRACT(info, '$.value'))""",
)
+ self.validate_identity(
+ "SELECT 1 AS row",
+ "SELECT 1 AS `row`",
+ )
# Index hints
self.validate_identity(
@@ -334,7 +338,7 @@ class TestMySQL(Validator):
write_CC = {
"bigquery": "SELECT 0xCC",
"clickhouse": "SELECT 0xCC",
- "databricks": "SELECT 204",
+ "databricks": "SELECT X'CC'",
"drill": "SELECT 204",
"duckdb": "SELECT 204",
"hive": "SELECT 204",
@@ -355,7 +359,7 @@ class TestMySQL(Validator):
write_CC_with_leading_zeros = {
"bigquery": "SELECT 0x0000CC",
"clickhouse": "SELECT 0x0000CC",
- "databricks": "SELECT 204",
+ "databricks": "SELECT X'0000CC'",
"drill": "SELECT 204",
"duckdb": "SELECT 204",
"hive": "SELECT 204",
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index a8a6c12..8ba4e96 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -38,8 +38,6 @@ class TestPostgres(Validator):
self.validate_identity("CAST(x AS TSTZMULTIRANGE)")
self.validate_identity("CAST(x AS DATERANGE)")
self.validate_identity("CAST(x AS DATEMULTIRANGE)")
- self.validate_identity("SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]")
- self.validate_identity("SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]")
self.validate_identity("x$")
self.validate_identity("SELECT ARRAY[1, 2, 3]")
self.validate_identity("SELECT ARRAY(SELECT 1)")
@@ -65,6 +63,10 @@ class TestPostgres(Validator):
self.validate_identity("SELECT CURRENT_USER")
self.validate_identity("SELECT * FROM ONLY t1")
self.validate_identity(
+ "SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]",
+ "SELECT ARRAY[1, 2] @> ARRAY[1, 2, 3]",
+ )
+ self.validate_identity(
"""UPDATE "x" SET "y" = CAST('0 days 60.000000 seconds' AS INTERVAL) WHERE "x"."id" IN (2, 3)"""
)
self.validate_identity(
@@ -326,6 +328,17 @@ class TestPostgres(Validator):
)
self.validate_all(
+ "SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]",
+ read={
+ "duckdb": "SELECT ARRAY_HAS_ALL([1, 2, 3], [1, 2])",
+ },
+ write={
+ "duckdb": "SELECT ARRAY_HAS_ALL([1, 2, 3], [1, 2])",
+ "mysql": UnsupportedError,
+ "postgres": "SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]",
+ },
+ )
+ self.validate_all(
"SELECT REGEXP_REPLACE('mr .', '[^a-zA-Z]', '', 'g')",
write={
"duckdb": "SELECT REGEXP_REPLACE('mr .', '[^a-zA-Z]', '', 'g')",
@@ -741,6 +754,9 @@ class TestPostgres(Validator):
self.validate_identity("ALTER TABLE t1 SET TABLESPACE tablespace")
self.validate_identity("ALTER TABLE t1 SET (fillfactor = 5, autovacuum_enabled = TRUE)")
self.validate_identity(
+ "CREATE FUNCTION pymax(a INT, b INT) RETURNS INT LANGUAGE plpython3u AS $$\n if a > b:\n return a\n return b\n$$",
+ )
+ self.validate_identity(
"CREATE TABLE t (vid INT NOT NULL, CONSTRAINT ht_vid_nid_fid_idx EXCLUDE (INT4RANGE(vid, nid) WITH &&, INT4RANGE(fid, fid, '[]') WITH &&))"
)
self.validate_identity(
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index d3c47af..f8c2ea1 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -10,6 +10,11 @@ class TestSnowflake(Validator):
dialect = "snowflake"
def test_snowflake(self):
+ self.validate_identity(
+ "transform(x, a int -> a + a + 1)",
+ "TRANSFORM(x, a -> CAST(a AS INT) + CAST(a AS INT) + 1)",
+ )
+
self.validate_all(
"ARRAY_CONSTRUCT_COMPACT(1, null, 2)",
write={
@@ -321,10 +326,12 @@ WHERE
"""SELECT PARSE_JSON('{"fruit":"banana"}'):fruit""",
write={
"bigquery": """SELECT JSON_EXTRACT(PARSE_JSON('{"fruit":"banana"}'), '$.fruit')""",
+ "databricks": """SELECT GET_JSON_OBJECT('{"fruit":"banana"}', '$.fruit')""",
"duckdb": """SELECT JSON('{"fruit":"banana"}') -> '$.fruit'""",
"mysql": """SELECT JSON_EXTRACT('{"fruit":"banana"}', '$.fruit')""",
"presto": """SELECT JSON_EXTRACT(JSON_PARSE('{"fruit":"banana"}'), '$.fruit')""",
"snowflake": """SELECT GET_PATH(PARSE_JSON('{"fruit":"banana"}'), 'fruit')""",
+ "spark": """SELECT GET_JSON_OBJECT('{"fruit":"banana"}', '$.fruit')""",
"tsql": """SELECT ISNULL(JSON_QUERY('{"fruit":"banana"}', '$.fruit'), JSON_VALUE('{"fruit":"banana"}', '$.fruit'))""",
},
)
@@ -1198,6 +1205,8 @@ WHERE
self.validate_identity("CREATE TABLE IDENTIFIER('foo') (COLUMN1 VARCHAR, COLUMN2 VARCHAR)")
self.validate_identity("CREATE TABLE IDENTIFIER($foo) (col1 VARCHAR, col2 VARCHAR)")
self.validate_identity("CREATE TAG cost_center ALLOWED_VALUES 'a', 'b'")
+ self.validate_identity("CREATE WAREHOUSE x").this.assert_is(exp.Identifier)
+ self.validate_identity("CREATE STREAMLIT x").this.assert_is(exp.Identifier)
self.validate_identity(
"CREATE OR REPLACE TAG IF NOT EXISTS cost_center COMMENT='cost_center tag'"
).this.assert_is(exp.Identifier)
@@ -1825,7 +1834,7 @@ STORAGE_AWS_ROLE_ARN='arn:aws:iam::001234567890:role/myrole'
ENABLED=TRUE
STORAGE_ALLOWED_LOCATIONS=('s3://mybucket1/path1/', 's3://mybucket2/path2/')""",
pretty=True,
- )
+ ).this.assert_is(exp.Identifier)
def test_swap(self):
ast = parse_one("ALTER TABLE a SWAP WITH b", read="snowflake")
diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py
index 010b683..74d5f88 100644
--- a/tests/dialects/test_teradata.py
+++ b/tests/dialects/test_teradata.py
@@ -38,7 +38,7 @@ class TestTeradata(Validator):
"UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1",
write={
"teradata": "UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1",
- "mysql": "UPDATE A SET col2 = '' FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1",
+ "mysql": "UPDATE A SET col2 = '' FROM `schema`.tableA AS A, (SELECT col1 FROM `schema`.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1",
},
)
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index 13a6153..e31031d 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -870,4 +870,5 @@ SELECT enum
SELECT unlogged
SELECT name
SELECT copy
-SELECT rollup \ No newline at end of file
+SELECT rollup
+SELECT unnest