summaryrefslogtreecommitdiffstats
path: root/tests/dataframe/integration/test_grouped_data.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dataframe/integration/test_grouped_data.py')
-rw-r--r--tests/dataframe/integration/test_grouped_data.py71
1 files changed, 71 insertions, 0 deletions
diff --git a/tests/dataframe/integration/test_grouped_data.py b/tests/dataframe/integration/test_grouped_data.py
new file mode 100644
index 0000000..2768dda
--- /dev/null
+++ b/tests/dataframe/integration/test_grouped_data.py
@@ -0,0 +1,71 @@
+from pyspark.sql import functions as F
+
+from sqlglot.dataframe.sql import functions as SF
+from tests.dataframe.integration.dataframe_validator import DataFrameValidator
+
+
+class TestDataframeFunc(DataFrameValidator):
+ def test_group_by(self):
+ df_employee = self.df_spark_employee.groupBy(self.df_spark_employee.age).agg(
+ F.min(self.df_spark_employee.employee_id)
+ )
+ dfs_employee = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).agg(
+ SF.min(self.df_sqlglot_employee.employee_id)
+ )
+ self.compare_spark_with_sqlglot(df_employee, dfs_employee, skip_schema_compare=True)
+
+ def test_group_by_where_non_aggregate(self):
+ df_employee = (
+ self.df_spark_employee.groupBy(self.df_spark_employee.age)
+ .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
+ .where(F.col("age") > F.lit(50))
+ )
+ dfs_employee = (
+ self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
+ .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
+ .where(SF.col("age") > SF.lit(50))
+ )
+ self.compare_spark_with_sqlglot(df_employee, dfs_employee)
+
+ def test_group_by_where_aggregate_like_having(self):
+ df_employee = (
+ self.df_spark_employee.groupBy(self.df_spark_employee.age)
+ .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
+ .where(F.col("min_employee_id") > F.lit(1))
+ )
+ dfs_employee = (
+ self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
+ .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
+ .where(SF.col("min_employee_id") > SF.lit(1))
+ )
+ self.compare_spark_with_sqlglot(df_employee, dfs_employee)
+
+ def test_count(self):
+ df = self.df_spark_employee.groupBy(self.df_spark_employee.age).count()
+ dfs = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).count()
+ self.compare_spark_with_sqlglot(df, dfs)
+
+ def test_mean(self):
+ df = self.df_spark_employee.groupBy().mean("age", "store_id")
+ dfs = self.df_sqlglot_employee.groupBy().mean("age", "store_id")
+ self.compare_spark_with_sqlglot(df, dfs)
+
+ def test_avg(self):
+ df = self.df_spark_employee.groupBy("age").avg("store_id")
+ dfs = self.df_sqlglot_employee.groupBy("age").avg("store_id")
+ self.compare_spark_with_sqlglot(df, dfs)
+
+ def test_max(self):
+ df = self.df_spark_employee.groupBy("age").max("store_id")
+ dfs = self.df_sqlglot_employee.groupBy("age").max("store_id")
+ self.compare_spark_with_sqlglot(df, dfs)
+
+ def test_min(self):
+ df = self.df_spark_employee.groupBy("age").min("store_id")
+ dfs = self.df_sqlglot_employee.groupBy("age").min("store_id")
+ self.compare_spark_with_sqlglot(df, dfs)
+
+ def test_sum(self):
+ df = self.df_spark_employee.groupBy("age").sum("store_id")
+ dfs = self.df_sqlglot_employee.groupBy("age").sum("store_id")
+ self.compare_spark_with_sqlglot(df, dfs)