summaryrefslogtreecommitdiffstats
path: root/tests/dataframe/integration/test_grouped_data.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dataframe/integration/test_grouped_data.py')
-rw-r--r--tests/dataframe/integration/test_grouped_data.py71
1 files changed, 0 insertions, 71 deletions
diff --git a/tests/dataframe/integration/test_grouped_data.py b/tests/dataframe/integration/test_grouped_data.py
deleted file mode 100644
index 2768dda..0000000
--- a/tests/dataframe/integration/test_grouped_data.py
+++ /dev/null
@@ -1,71 +0,0 @@
-from pyspark.sql import functions as F
-
-from sqlglot.dataframe.sql import functions as SF
-from tests.dataframe.integration.dataframe_validator import DataFrameValidator
-
-
-class TestDataframeFunc(DataFrameValidator):
- def test_group_by(self):
- df_employee = self.df_spark_employee.groupBy(self.df_spark_employee.age).agg(
- F.min(self.df_spark_employee.employee_id)
- )
- dfs_employee = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).agg(
- SF.min(self.df_sqlglot_employee.employee_id)
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee, skip_schema_compare=True)
-
- def test_group_by_where_non_aggregate(self):
- df_employee = (
- self.df_spark_employee.groupBy(self.df_spark_employee.age)
- .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
- .where(F.col("age") > F.lit(50))
- )
- dfs_employee = (
- self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
- .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
- .where(SF.col("age") > SF.lit(50))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_group_by_where_aggregate_like_having(self):
- df_employee = (
- self.df_spark_employee.groupBy(self.df_spark_employee.age)
- .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
- .where(F.col("min_employee_id") > F.lit(1))
- )
- dfs_employee = (
- self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
- .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
- .where(SF.col("min_employee_id") > SF.lit(1))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_count(self):
- df = self.df_spark_employee.groupBy(self.df_spark_employee.age).count()
- dfs = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).count()
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_mean(self):
- df = self.df_spark_employee.groupBy().mean("age", "store_id")
- dfs = self.df_sqlglot_employee.groupBy().mean("age", "store_id")
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_avg(self):
- df = self.df_spark_employee.groupBy("age").avg("store_id")
- dfs = self.df_sqlglot_employee.groupBy("age").avg("store_id")
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_max(self):
- df = self.df_spark_employee.groupBy("age").max("store_id")
- dfs = self.df_sqlglot_employee.groupBy("age").max("store_id")
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_min(self):
- df = self.df_spark_employee.groupBy("age").min("store_id")
- dfs = self.df_sqlglot_employee.groupBy("age").min("store_id")
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_sum(self):
- df = self.df_spark_employee.groupBy("age").sum("store_id")
- dfs = self.df_sqlglot_employee.groupBy("age").sum("store_id")
- self.compare_spark_with_sqlglot(df, dfs)