summaryrefslogtreecommitdiffstats
path: root/tests/dataframe/integration/test_grouped_data.py
blob: 2768dda46a4e2817ee7a9433e878e7d1422bffda (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from pyspark.sql import functions as F

from sqlglot.dataframe.sql import functions as SF
from tests.dataframe.integration.dataframe_validator import DataFrameValidator


class TestDataframeFunc(DataFrameValidator):
    """Integration tests for ``GroupedData`` operations: PySpark vs. sqlglot.

    Every test describes its query once, as a builder callable. That builder is
    applied to the PySpark employee dataframe (with ``pyspark.sql.functions``)
    and to the sqlglot employee dataframe (with
    ``sqlglot.dataframe.sql.functions``). The two results are then handed to
    ``compare_spark_with_sqlglot``.
    """

    def _compare_built_pair(self, build, **compare_kwargs):
        """Build one query against both engines and compare the results.

        :param build: callable invoked as ``build(dataframe, functions_module)``;
            it is called first with the PySpark frame and ``F``, then with the
            sqlglot frame and ``SF``.
        :param compare_kwargs: extra keyword arguments forwarded unchanged to
            ``compare_spark_with_sqlglot``.
        """
        spark_result = build(self.df_spark_employee, F)
        sqlglot_result = build(self.df_sqlglot_employee, SF)
        self.compare_spark_with_sqlglot(spark_result, sqlglot_result, **compare_kwargs)

    def test_group_by(self):
        """Group by the ``age`` column object and aggregate ``min(employee_id)``."""
        # Schema comparison is skipped for this case; presumably the unaliased
        # aggregate column is named differently by the two engines — TODO confirm.
        self._compare_built_pair(
            lambda frame, fn: frame.groupBy(frame.age).agg(fn.min(frame.employee_id)),
            skip_schema_compare=True,
        )

    def test_group_by_where_non_aggregate(self):
        """Filter an aggregated result on the grouping column (``age > 50``)."""
        self._compare_built_pair(
            lambda frame, fn: (
                frame.groupBy(frame.age)
                .agg(fn.min(frame.employee_id).alias("min_employee_id"))
                .where(fn.col("age") > fn.lit(50))
            )
        )

    def test_group_by_where_aggregate_like_having(self):
        """Filter on the aliased aggregate itself, i.e. a HAVING-style predicate."""
        self._compare_built_pair(
            lambda frame, fn: (
                frame.groupBy(frame.age)
                .agg(fn.min(frame.employee_id).alias("min_employee_id"))
                .where(fn.col("min_employee_id") > fn.lit(1))
            )
        )

    def test_count(self):
        """``GroupedData.count`` grouped by the ``age`` column object."""
        self._compare_built_pair(lambda frame, _fn: frame.groupBy(frame.age).count())

    def test_mean(self):
        """``GroupedData.mean`` over two columns with no grouping keys."""
        self._compare_built_pair(lambda frame, _fn: frame.groupBy().mean("age", "store_id"))

    def test_avg(self):
        """``GroupedData.avg`` of ``store_id`` grouped by ``age`` (by name)."""
        self._compare_built_pair(lambda frame, _fn: frame.groupBy("age").avg("store_id"))

    def test_max(self):
        """``GroupedData.max`` of ``store_id`` grouped by ``age`` (by name)."""
        self._compare_built_pair(lambda frame, _fn: frame.groupBy("age").max("store_id"))

    def test_min(self):
        """``GroupedData.min`` of ``store_id`` grouped by ``age`` (by name)."""
        self._compare_built_pair(lambda frame, _fn: frame.groupBy("age").min("store_id"))

    def test_sum(self):
        """``GroupedData.sum`` of ``store_id`` grouped by ``age`` (by name)."""
        self._compare_built_pair(lambda frame, _fn: frame.groupBy("age").sum("store_id"))