
sqlglot.dataframe.sql

 1from sqlglot.dataframe.sql.column import Column
 2from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
 3from sqlglot.dataframe.sql.group import GroupedData
 4from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
 5from sqlglot.dataframe.sql.session import SparkSession
 6from sqlglot.dataframe.sql.window import Window, WindowSpec
 7
 8__all__ = [
 9    "SparkSession",
10    "DataFrame",
11    "GroupedData",
12    "Column",
13    "DataFrameNaFunctions",
14    "Window",
15    "WindowSpec",
16    "DataFrameReader",
17    "DataFrameWriter",
18]
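
These exports mirror a small subset of the PySpark DataFrame API, but instead of executing anything they build SQL strings. A minimal sketch of the round trip (column names are illustrative):

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([(1, "Jack"), (2, "Jill")], ["id", "name"])
    # sql() returns a list of SQL strings rather than running a query
    print(df.sql(dialect="spark")[0])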
class SparkSession:
 20class SparkSession:
 21    known_ids: t.ClassVar[t.Set[str]] = set()
 22    known_branch_ids: t.ClassVar[t.Set[str]] = set()
 23    known_sequence_ids: t.ClassVar[t.Set[str]] = set()
 24    name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)
 25
 26    def __init__(self):
 27        self.incrementing_id = 1
 28
 29    def __getattr__(self, name: str) -> SparkSession:
 30        return self
 31
 32    def __call__(self, *args, **kwargs) -> SparkSession:
 33        return self
 34
 35    @property
 36    def read(self) -> DataFrameReader:
 37        return DataFrameReader(self)
 38
 39    def table(self, tableName: str) -> DataFrame:
 40        return self.read.table(tableName)
 41
 42    def createDataFrame(
 43        self,
 44        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
 45        schema: t.Optional[SchemaInput] = None,
 46        samplingRatio: t.Optional[float] = None,
 47        verifySchema: bool = False,
 48    ) -> DataFrame:
 49        from sqlglot.dataframe.sql.dataframe import DataFrame
 50
 51        if samplingRatio is not None or verifySchema:
 52            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
 53        if schema is not None and (
 54            not isinstance(schema, (StructType, str, list))
 55            or (isinstance(schema, list) and not isinstance(schema[0], str))
 56        ):
 57            raise NotImplementedError("Only schemas of type list or string are supported")
 58        if not data:
 59            raise ValueError("Must provide data to create a DataFrame")
 60
 61        column_mapping: t.Dict[str, t.Optional[str]]
 62        if schema is not None:
 63            column_mapping = get_column_mapping_from_schema_input(schema)
 64        elif isinstance(data[0], dict):
 65            column_mapping = {col_name.strip(): None for col_name in data[0]}
 66        else:
 67            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}
 68
 69        data_expressions = [
 70            exp.Tuple(
 71                expressions=list(
 72                    map(
 73                        lambda x: F.lit(x).expression,
 74                        row if not isinstance(row, dict) else row.values(),
 75                    )
 76                )
 77            )
 78            for row in data
 79        ]
 80
 81        sel_columns = [
 82            F.col(name).cast(data_type).alias(name).expression
 83            if data_type is not None
 84            else F.col(name).expression
 85            for name, data_type in column_mapping.items()
 86        ]
 87
 88        select_kwargs = {
 89            "expressions": sel_columns,
 90            "from": exp.From(
 91                expressions=[
 92                    exp.Values(
 93                        expressions=data_expressions,
 94                        alias=exp.TableAlias(
 95                            this=exp.to_identifier(self._auto_incrementing_name),
 96                            columns=[exp.to_identifier(col_name) for col_name in column_mapping],
 97                        ),
 98                    ),
 99                ],
100            ),
101        }
102
103        sel_expression = exp.Select(**select_kwargs)
104        return DataFrame(self, sel_expression)
105
106    def sql(self, sqlQuery: str) -> DataFrame:
107        expression = sqlglot.parse_one(sqlQuery, read="spark")
108        if isinstance(expression, exp.Select):
109            df = DataFrame(self, expression)
110            df = df._convert_leaf_to_cte()
111        elif isinstance(expression, (exp.Create, exp.Insert)):
112            select_expression = expression.expression.copy()
113            if isinstance(expression, exp.Insert):
114                select_expression.set("with", expression.args.get("with"))
115                expression.set("with", None)
116            del expression.args["expression"]
117            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
118            df = df._convert_leaf_to_cte()
119        else:
120            raise ValueError(
121                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
122            )
123        return df
124
125    @property
126    def _auto_incrementing_name(self) -> str:
127        name = f"a{self.incrementing_id}"
128        self.incrementing_id += 1
129        return name
130
131    @property
132    def _random_name(self) -> str:
133        return "r" + uuid.uuid4().hex
134
135    @property
136    def _random_branch_id(self) -> str:
137        id = self._random_id
138        self.known_branch_ids.add(id)
139        return id
140
141    @property
142    def _random_sequence_id(self):
143        id = self._random_id
144        self.known_sequence_ids.add(id)
145        return id
146
147    @property
148    def _random_id(self) -> str:
149        id = self._random_name
150        self.known_ids.add(id)
151        return id
152
153    @property
154    def _join_hint_names(self) -> t.Set[str]:
155        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}
156
157    def _add_alias_to_mapping(self, name: str, sequence_id: str):
158        self.name_to_sequence_id_mapping[name].append(sequence_id)
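
Note that __getattr__ and __call__ both return the session itself, so PySpark-style builder chains on an instance collapse into no-ops instead of failing. A small sketch of that behavior:

    spark = SparkSession()
    # every attribute access and call along the chain returns the same session
    same = spark.builder.appName("demo").getOrCreate()
    assert same is spark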
def table(self, tableName: str) -> sqlglot.dataframe.sql.DataFrame:
39    def table(self, tableName: str) -> DataFrame:
40        return self.read.table(tableName)
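
table defers to DataFrameReader.table. For the generated SQL to be qualified and optimized correctly, the table's columns should be registered with sqlglot's schema first; a sketch, using an illustrative table:

    import sqlglot

    sqlglot.schema.add_table("employee", {"employee_id": "INT", "fname": "STRING", "age": "INT"})
    df = SparkSession().table("employee")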
def createDataFrame( self, data: Sequence[Union[Dict[str, ColumnLiterals], List[ColumnLiterals], Tuple]], schema: Optional[SchemaInput] = None, samplingRatio: Optional[float] = None, verifySchema: bool = False) -> sqlglot.dataframe.sql.DataFrame:
 42    def createDataFrame(
 43        self,
 44        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
 45        schema: t.Optional[SchemaInput] = None,
 46        samplingRatio: t.Optional[float] = None,
 47        verifySchema: bool = False,
 48    ) -> DataFrame:
 49        from sqlglot.dataframe.sql.dataframe import DataFrame
 50
 51        if samplingRatio is not None or verifySchema:
 52            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
 53        if schema is not None and (
 54            not isinstance(schema, (StructType, str, list))
 55            or (isinstance(schema, list) and not isinstance(schema[0], str))
 56        ):
 57            raise NotImplementedError("Only schemas of type list or string are supported")
 58        if not data:
 59            raise ValueError("Must provide data to create a DataFrame")
 60
 61        column_mapping: t.Dict[str, t.Optional[str]]
 62        if schema is not None:
 63            column_mapping = get_column_mapping_from_schema_input(schema)
 64        elif isinstance(data[0], dict):
 65            column_mapping = {col_name.strip(): None for col_name in data[0]}
 66        else:
 67            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}
 68
 69        data_expressions = [
 70            exp.Tuple(
 71                expressions=list(
 72                    map(
 73                        lambda x: F.lit(x).expression,
 74                        row if not isinstance(row, dict) else row.values(),
 75                    )
 76                )
 77            )
 78            for row in data
 79        ]
 80
 81        sel_columns = [
 82            F.col(name).cast(data_type).alias(name).expression
 83            if data_type is not None
 84            else F.col(name).expression
 85            for name, data_type in column_mapping.items()
 86        ]
 87
 88        select_kwargs = {
 89            "expressions": sel_columns,
 90            "from": exp.From(
 91                expressions=[
 92                    exp.Values(
 93                        expressions=data_expressions,
 94                        alias=exp.TableAlias(
 95                            this=exp.to_identifier(self._auto_incrementing_name),
 96                            columns=[exp.to_identifier(col_name) for col_name in column_mapping],
 97                        ),
 98                    ),
 99                ],
100            ),
101        }
102
103        sel_expression = exp.Select(**select_kwargs)
104        return DataFrame(self, sel_expression)
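
As the source above shows, rows may be dicts, lists, or tuples, and the schema may be a StructType, a string, or a list of column names. A sketch of the accepted shapes (values are illustrative):

    spark = SparkSession()

    # rows as dicts: column names are taken from the keys of the first row
    df1 = spark.createDataFrame([{"id": 1, "name": "Jack"}])

    # rows as tuples plus a list-of-strings schema
    df2 = spark.createDataFrame([(1, "Jack")], schema=["id", "name"])

    # rows as tuples without a schema: columns are auto-named _1, _2, ...
    df3 = spark.createDataFrame([(1, "Jack")])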
def sql(self, sqlQuery: str) -> sqlglot.dataframe.sql.DataFrame:
106    def sql(self, sqlQuery: str) -> DataFrame:
107        expression = sqlglot.parse_one(sqlQuery, read="spark")
108        if isinstance(expression, exp.Select):
109            df = DataFrame(self, expression)
110            df = df._convert_leaf_to_cte()
111        elif isinstance(expression, (exp.Create, exp.Insert)):
112            select_expression = expression.expression.copy()
113            if isinstance(expression, exp.Insert):
114                select_expression.set("with", expression.args.get("with"))
115                expression.set("with", None)
116            del expression.args["expression"]
117            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
118            df = df._convert_leaf_to_cte()
119        else:
120            raise ValueError(
121                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
122            )
123        return df
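
sql parses a single SELECT, CREATE, or INSERT statement with the Spark dialect and wraps it in a DataFrame, so it can be refined further with DataFrame methods. A minimal sketch, assuming the illustrative employee schema registered earlier:

    spark = SparkSession()
    df = spark.sql("SELECT employee_id FROM employee")
    df = df.where(df["employee_id"] > 100)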
class DataFrame:
 45class DataFrame:
 46    def __init__(
 47        self,
 48        spark: SparkSession,
 49        expression: exp.Select,
 50        branch_id: t.Optional[str] = None,
 51        sequence_id: t.Optional[str] = None,
 52        last_op: Operation = Operation.INIT,
 53        pending_hints: t.Optional[t.List[exp.Expression]] = None,
 54        output_expression_container: t.Optional[OutputExpressionContainer] = None,
 55        **kwargs,
 56    ):
 57        self.spark = spark
 58        self.expression = expression
 59        self.branch_id = branch_id or self.spark._random_branch_id
 60        self.sequence_id = sequence_id or self.spark._random_sequence_id
 61        self.last_op = last_op
 62        self.pending_hints = pending_hints or []
 63        self.output_expression_container = output_expression_container or exp.Select()
 64
 65    def __getattr__(self, column_name: str) -> Column:
 66        return self[column_name]
 67
 68    def __getitem__(self, column_name: str) -> Column:
 69        column_name = f"{self.branch_id}.{column_name}"
 70        return Column(column_name)
 71
 72    def __copy__(self):
 73        return self.copy()
 74
 75    @property
 76    def sparkSession(self):
 77        return self.spark
 78
 79    @property
 80    def write(self):
 81        return DataFrameWriter(self)
 82
 83    @property
 84    def latest_cte_name(self) -> str:
 85        if not self.expression.ctes:
 86            from_exp = self.expression.args["from"]
 87            if from_exp.alias_or_name:
 88                return from_exp.alias_or_name
 89            table_alias = from_exp.find(exp.TableAlias)
 90            if not table_alias:
 91                raise RuntimeError(
 92                    f"Could not find an alias name for this expression: {self.expression}"
 93                )
 94            return table_alias.alias_or_name
 95        return self.expression.ctes[-1].alias
 96
 97    @property
 98    def pending_join_hints(self):
 99        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]
100
101    @property
102    def pending_partition_hints(self):
103        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]
104
105    @property
106    def columns(self) -> t.List[str]:
107        return self.expression.named_selects
108
109    @property
110    def na(self) -> DataFrameNaFunctions:
111        return DataFrameNaFunctions(self)
112
113    def _replace_cte_names_with_hashes(self, expression: exp.Select):
114        replacement_mapping = {}
115        for cte in expression.ctes:
116            old_name_id = cte.args["alias"].this
117            new_hashed_id = exp.to_identifier(
118                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
119            )
120            replacement_mapping[old_name_id] = new_hashed_id
121            expression = expression.transform(replace_id_value, replacement_mapping)
122        return expression
123
124    def _create_cte_from_expression(
125        self,
126        expression: exp.Expression,
127        branch_id: t.Optional[str] = None,
128        sequence_id: t.Optional[str] = None,
129        **kwargs,
130    ) -> t.Tuple[exp.CTE, str]:
131        name = self.spark._random_name
132        expression_to_cte = expression.copy()
133        expression_to_cte.set("with", None)
134        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
135        cte.set("branch_id", branch_id or self.branch_id)
136        cte.set("sequence_id", sequence_id or self.sequence_id)
137        return cte, name
138
139    @t.overload
140    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
141        ...
142
143    @t.overload
144    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
145        ...
146
147    def _ensure_list_of_columns(self, cols):
148        return Column.ensure_cols(ensure_list(cols))
149
150    def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
151        cols = self._ensure_list_of_columns(cols)
152        normalize(self.spark, expression or self.expression, cols)
153        return cols
154
155    def _ensure_and_normalize_col(self, col):
156        col = Column.ensure_col(col)
157        normalize(self.spark, self.expression, col)
158        return col
159
160    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
161        df = self._resolve_pending_hints()
162        sequence_id = sequence_id or df.sequence_id
163        expression = df.expression.copy()
164        cte_expression, cte_name = df._create_cte_from_expression(
165            expression=expression, sequence_id=sequence_id
166        )
167        new_expression = df._add_ctes_to_expression(
168            exp.Select(), expression.ctes + [cte_expression]
169        )
170        sel_columns = df._get_outer_select_columns(cte_expression)
171        new_expression = new_expression.from_(cte_name).select(
172            *[x.alias_or_name for x in sel_columns]
173        )
174        return df.copy(expression=new_expression, sequence_id=sequence_id)
175
176    def _resolve_pending_hints(self) -> DataFrame:
177        df = self.copy()
178        if not self.pending_hints:
179            return df
180        expression = df.expression
181        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
182        for hint in df.pending_partition_hints:
183            hint_expression.append("expressions", hint)
184            df.pending_hints.remove(hint)
185
186        join_aliases = {
187            join_table.alias_or_name
188            for join_table in get_tables_from_expression_with_join(expression)
189        }
190        if join_aliases:
191            for hint in df.pending_join_hints:
192                for sequence_id_expression in hint.expressions:
193                    sequence_id_or_name = sequence_id_expression.alias_or_name
194                    sequence_ids_to_match = [sequence_id_or_name]
195                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
196                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
197                            sequence_id_or_name
198                        ]
199                    matching_ctes = [
200                        cte
201                        for cte in reversed(expression.ctes)
202                        if cte.args["sequence_id"] in sequence_ids_to_match
203                    ]
204                    for matching_cte in matching_ctes:
205                        if matching_cte.alias_or_name in join_aliases:
206                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
207                            df.pending_hints.remove(hint)
208                            break
209                hint_expression.append("expressions", hint)
210        if hint_expression.expressions:
211            expression.set("hint", hint_expression)
212        return df
213
214    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
215        hint_name = hint_name.upper()
216        hint_expression = (
217            exp.JoinHint(
218                this=hint_name,
219                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
220            )
221            if hint_name in JOIN_HINTS
222            else exp.Anonymous(
223                this=hint_name, expressions=[parameter.expression for parameter in args]
224            )
225        )
226        new_df = self.copy()
227        new_df.pending_hints.append(hint_expression)
228        return new_df
229
230    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
231        other_df = other._convert_leaf_to_cte()
232        base_expression = self.expression.copy()
233        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
234        all_ctes = base_expression.ctes
235        other_df.expression.set("with", None)
236        base_expression.set("with", None)
237        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
238        operation.set("with", exp.With(expressions=all_ctes))
239        return self.copy(expression=operation)._convert_leaf_to_cte()
240
241    def _cache(self, storage_level: str):
242        df = self._convert_leaf_to_cte()
243        df.expression.ctes[-1].set("cache_storage_level", storage_level)
244        return df
245
246    @classmethod
247    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
248        expression = expression.copy()
249        with_expression = expression.args.get("with")
250        if with_expression:
251            existing_ctes = with_expression.expressions
252            existing_cte_names = {x.alias_or_name for x in existing_ctes}
253            for cte in ctes:
254                if cte.alias_or_name not in existing_cte_names:
255                    existing_ctes.append(cte)
256        else:
257            existing_ctes = ctes
258        expression.set("with", exp.With(expressions=existing_ctes))
259        return expression
260
261    @classmethod
262    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
263        expression = item.expression if isinstance(item, DataFrame) else item
264        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]
265
266    @classmethod
267    def _create_hash_from_expression(cls, expression: exp.Select):
268        value = expression.sql(dialect="spark").encode("utf-8")
269        return f"t{zlib.crc32(value)}"[:6]
270
271    def _get_select_expressions(
272        self,
273    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
274        select_expressions: t.List[
275            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
276        ] = []
277        main_select_ctes: t.List[exp.CTE] = []
278        for cte in self.expression.ctes:
279            cache_storage_level = cte.args.get("cache_storage_level")
280            if cache_storage_level:
281                select_expression = cte.this.copy()
282                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
283                select_expression.set("cte_alias_name", cte.alias_or_name)
284                select_expression.set("cache_storage_level", cache_storage_level)
285                select_expressions.append((exp.Cache, select_expression))
286            else:
287                main_select_ctes.append(cte)
288        main_select = self.expression.copy()
289        if main_select_ctes:
290            main_select.set("with", exp.With(expressions=main_select_ctes))
291        expression_select_pair = (type(self.output_expression_container), main_select)
292        select_expressions.append(expression_select_pair)  # type: ignore
293        return select_expressions
294
295    def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
296        df = self._resolve_pending_hints()
297        select_expressions = df._get_select_expressions()
298        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
299        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
300        for expression_type, select_expression in select_expressions:
301            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
302            if optimize:
303                select_expression = optimize_func(select_expression, identify="always")
304            select_expression = df._replace_cte_names_with_hashes(select_expression)
305            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
306            if expression_type == exp.Cache:
307                cache_table_name = df._create_hash_from_expression(select_expression)
308                cache_table = exp.to_table(cache_table_name)
309                original_alias_name = select_expression.args["cte_alias_name"]
310
311                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
312                    cache_table_name
313                )
314                sqlglot.schema.add_table(
315                    cache_table_name,
316                    {
317                        expression.alias_or_name: expression.type.sql("spark")
318                        for expression in select_expression.expressions
319                    },
320                )
321                cache_storage_level = select_expression.args["cache_storage_level"]
322                options = [
323                    exp.Literal.string("storageLevel"),
324                    exp.Literal.string(cache_storage_level),
325                ]
326                expression = exp.Cache(
327                    this=cache_table, expression=select_expression, lazy=True, options=options
328                )
329                # We will drop the "view" if it exists before running the cache table
330                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
331            elif expression_type == exp.Create:
332                expression = df.output_expression_container.copy()
333                expression.set("expression", select_expression)
334            elif expression_type == exp.Insert:
335                expression = df.output_expression_container.copy()
336                select_without_ctes = select_expression.copy()
337                select_without_ctes.set("with", None)
338                expression.set("expression", select_without_ctes)
339                if select_expression.ctes:
340                    expression.set("with", exp.With(expressions=select_expression.ctes))
341            elif expression_type == exp.Select:
342                expression = select_expression
343            else:
344                raise ValueError(f"Invalid expression type: {expression_type}")
345            output_expressions.append(expression)
346
347        return [
348            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
349        ]
350
351    def copy(self, **kwargs) -> DataFrame:
352        return DataFrame(**object_to_dict(self, **kwargs))
353
354    @operation(Operation.SELECT)
355    def select(self, *cols, **kwargs) -> DataFrame:
356        cols = self._ensure_and_normalize_cols(cols)
357        kwargs["append"] = kwargs.get("append", False)
358        if self.expression.args.get("joins"):
359            ambiguous_cols = [
360                col
361                for col in cols
362                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
363            ]
364            if ambiguous_cols:
365                join_table_identifiers = [
366                    x.this for x in get_tables_from_expression_with_join(self.expression)
367                ]
368                cte_names_in_join = [x.this for x in join_table_identifiers]
369                # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
370                # and therefore we allow multiple columns with the same name in the result. This matches the behavior
371                # of Spark.
372                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
373                for ambiguous_col in ambiguous_cols:
374                    ctes_with_column = [
375                        cte
376                        for cte in self.expression.ctes
377                        if cte.alias_or_name in cte_names_in_join
378                        and ambiguous_col.alias_or_name in cte.this.named_selects
379                    ]
380                    # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
381                    # use the same CTE we used before
382                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
383                    if cte:
384                        resolved_column_position[ambiguous_col] += 1
385                    else:
386                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
387                    ambiguous_col.expression.set("table", cte.alias_or_name)
388        return self.copy(
389            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
390        )
391
392    @operation(Operation.NO_OP)
393    def alias(self, name: str, **kwargs) -> DataFrame:
394        new_sequence_id = self.spark._random_sequence_id
395        df = self.copy()
396        for join_hint in df.pending_join_hints:
397            for expression in join_hint.expressions:
398                if expression.alias_or_name == self.sequence_id:
399                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
400        df.spark._add_alias_to_mapping(name, new_sequence_id)
401        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
402
403    @operation(Operation.WHERE)
404    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
405        col = self._ensure_and_normalize_col(column)
406        return self.copy(expression=self.expression.where(col.expression))
407
408    filter = where
409
410    @operation(Operation.GROUP_BY)
411    def groupBy(self, *cols, **kwargs) -> GroupedData:
412        columns = self._ensure_and_normalize_cols(cols)
413        return GroupedData(self, columns, self.last_op)
414
415    @operation(Operation.SELECT)
416    def agg(self, *exprs, **kwargs) -> DataFrame:
417        cols = self._ensure_and_normalize_cols(exprs)
418        return self.groupBy().agg(*cols)
419
420    @operation(Operation.FROM)
421    def join(
422        self,
423        other_df: DataFrame,
424        on: t.Union[str, t.List[str], Column, t.List[Column]],
425        how: str = "inner",
426        **kwargs,
427    ) -> DataFrame:
428        other_df = other_df._convert_leaf_to_cte()
429        join_columns = self._ensure_list_of_columns(on)
430        # We will determine the actual "join on" expression later, so we don't provide it at first
431        join_expression = self.expression.join(
432            other_df.latest_cte_name, join_type=how.replace("_", " ")
433        )
434        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
435        self_columns = self._get_outer_select_columns(join_expression)
436        other_columns = self._get_outer_select_columns(other_df)
437        # Determines the join clause and select columns to be used based on what type of columns were provided for
438        # the join. The columns returned change based on how the on expression is provided.
439        if isinstance(join_columns[0].expression, exp.Column):
440            """
441            Unique characteristics of join on column names only:
442            * The column names are put at the front of the select list
443            * Only the join column names are deduplicated across the entire select list (other duplicates are allowed)
444            """
445            table_names = [
446                table.alias_or_name
447                for table in get_tables_from_expression_with_join(join_expression)
448            ]
449            potential_ctes = [
450                cte
451                for cte in join_expression.ctes
452                if cte.alias_or_name in table_names
453                and cte.alias_or_name != other_df.latest_cte_name
454            ]
455            # Determine the table to reference for the left side of the join by checking each of the left side
456            # tables and seeing if they have the column being referenced.
457            join_column_pairs = []
458            for join_column in join_columns:
459                num_matching_ctes = 0
460                for cte in potential_ctes:
461                    if join_column.alias_or_name in cte.this.named_selects:
462                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
463                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
464                        join_column_pairs.append((left_column, right_column))
465                        num_matching_ctes += 1
466                if num_matching_ctes > 1:
467                    raise ValueError(
468                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
469                    )
470                elif num_matching_ctes == 0:
471                    raise ValueError(
472                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
473                    )
474            join_clause = functools.reduce(
475                lambda x, y: x & y,
476                [left_column == right_column for left_column, right_column in join_column_pairs],
477            )
478            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
479            # To match Spark's behavior, only the join columns are deduplicated and they are put at the front of the column list
480            select_column_names = [
481                column.alias_or_name
482                if not isinstance(column.expression.this, exp.Star)
483                else column.sql()
484                for column in self_columns + other_columns
485            ]
486            select_column_names = [
487                column_name
488                for column_name in select_column_names
489                if column_name not in join_column_names
490            ]
491            select_column_names = join_column_names + select_column_names
492        else:
493            """
494            Unique characteristics of join on expressions:
495            * There is no deduplication of the results.
496            * The left DataFrame's columns go first and the right's come after. No sort preference is given to join columns
497            """
498            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
499            if len(join_columns) > 1:
500                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
501            join_clause = join_columns[0]
502            select_column_names = [column.alias_or_name for column in self_columns + other_columns]
503
504        # Update the on expression with the actual join clause to replace the dummy one from before
505        join_expression.args["joins"][-1].set("on", join_clause.expression)
506        new_df = self.copy(expression=join_expression)
507        new_df.pending_join_hints.extend(self.pending_join_hints)
508        new_df.pending_hints.extend(other_df.pending_hints)
509        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
510        return new_df
511
512    @operation(Operation.ORDER_BY)
513    def orderBy(
514        self,
515        *cols: t.Union[str, Column],
516        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
517    ) -> DataFrame:
518        """
519        This implementation lets any pre-ordered columns take priority over whatever is provided in `ascending`.
520        Spark's behavior here is irregular and can result in runtime errors. Users shouldn't mix the two anyway,
521        so this is unlikely to come up.
522        """
523        columns = self._ensure_and_normalize_cols(cols)
524        pre_ordered_col_indexes = [
525            x
526            for x in [
527                i if isinstance(col.expression, exp.Ordered) else None
528                for i, col in enumerate(columns)
529            ]
530            if x is not None
531        ]
532        if ascending is None:
533            ascending = [True] * len(columns)
534        elif not isinstance(ascending, list):
535            ascending = [ascending] * len(columns)
536        ascending = [bool(x) for x in ascending]
537        assert len(columns) == len(
538            ascending
539        ), "The length of items in ascending must equal the number of columns provided"
540        col_and_ascending = list(zip(columns, ascending))
541        order_by_columns = [
542            exp.Ordered(this=col.expression, desc=not asc)
543            if i not in pre_ordered_col_indexes
544            else columns[i].column_expression
545            for i, (col, asc) in enumerate(col_and_ascending)
546        ]
547        return self.copy(expression=self.expression.order_by(*order_by_columns))
548
549    sort = orderBy
550
551    @operation(Operation.FROM)
552    def union(self, other: DataFrame) -> DataFrame:
553        return self._set_operation(exp.Union, other, False)
554
555    unionAll = union
556
557    @operation(Operation.FROM)
558    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
559        l_columns = self.columns
560        r_columns = other.columns
561        if not allowMissingColumns:
562            l_expressions = l_columns
563            r_expressions = l_columns
564        else:
565            l_expressions = []
566            r_expressions = []
567            r_columns_unused = copy(r_columns)
568            for l_column in l_columns:
569                l_expressions.append(l_column)
570                if l_column in r_columns:
571                    r_expressions.append(l_column)
572                    r_columns_unused.remove(l_column)
573                else:
574                    r_expressions.append(exp.alias_(exp.Null(), l_column))
575            for r_column in r_columns_unused:
576                l_expressions.append(exp.alias_(exp.Null(), r_column))
577                r_expressions.append(r_column)
578        r_df = (
579            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
580        )
581        l_df = self.copy()
582        if allowMissingColumns:
583            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
584        return l_df._set_operation(exp.Union, r_df, False)
585
586    @operation(Operation.FROM)
587    def intersect(self, other: DataFrame) -> DataFrame:
588        return self._set_operation(exp.Intersect, other, True)
589
590    @operation(Operation.FROM)
591    def intersectAll(self, other: DataFrame) -> DataFrame:
592        return self._set_operation(exp.Intersect, other, False)
593
594    @operation(Operation.FROM)
595    def exceptAll(self, other: DataFrame) -> DataFrame:
596        return self._set_operation(exp.Except, other, False)
597
598    @operation(Operation.SELECT)
599    def distinct(self) -> DataFrame:
600        return self.copy(expression=self.expression.distinct())
601
602    @operation(Operation.SELECT)
603    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
604        if not subset:
605            return self.distinct()
606        column_names = ensure_list(subset)
607        window = Window.partitionBy(*column_names).orderBy(*column_names)
608        return (
609            self.copy()
610            .withColumn("row_num", F.row_number().over(window))
611            .where(F.col("row_num") == F.lit(1))
612            .drop("row_num")
613        )
614
615    @operation(Operation.FROM)
616    def dropna(
617        self,
618        how: str = "any",
619        thresh: t.Optional[int] = None,
620        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
621    ) -> DataFrame:
622        minimum_non_null = thresh or 0  # will be determined later if thresh is None
623        new_df = self.copy()
624        all_columns = self._get_outer_select_columns(new_df.expression)
625        if subset:
626            null_check_columns = self._ensure_and_normalize_cols(subset)
627        else:
628            null_check_columns = all_columns
629        if thresh is None:
630            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
631        else:
632            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
633        if minimum_num_nulls > len(null_check_columns):
634            raise RuntimeError(
635                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
636                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
637            )
638        if_null_checks = [
639            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
640        ]
641        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
642        num_nulls = nulls_added_together.alias("num_nulls")
643        new_df = new_df.select(num_nulls, append=True)
644        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
645        final_df = filtered_df.select(*all_columns)
646        return final_df
647
648    @operation(Operation.FROM)
649    def fillna(
650        self,
651        value: t.Union[ColumnLiterals],
652        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
653    ) -> DataFrame:
654        """
655        Functionality difference: if you provide a replacement value whose type conflicts with the
656        type of the column, PySpark will simply ignore your replacement, whereas this implementation
657        will try to cast the two to a common type in some cases, so the results won't always match.
658        It is best not to mix types: make sure the replacement is the same type as the column.
659
660        Possible improvement: use the `typeof` function to get the type of the column
661        and check whether it matches the type of the value provided; if not, produce null.
662        """
663        from sqlglot.dataframe.sql.functions import lit
664
665        values = None
666        columns = None
667        new_df = self.copy()
668        all_columns = self._get_outer_select_columns(new_df.expression)
669        all_column_mapping = {column.alias_or_name: column for column in all_columns}
670        if isinstance(value, dict):
671            values = list(value.values())
672            columns = self._ensure_and_normalize_cols(list(value))
673        if not columns:
674            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
675        if not values:
676            values = [value] * len(columns)
677        value_columns = [lit(value) for value in values]
678
679        null_replacement_mapping = {
680            column.alias_or_name: (
681                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
682            )
683            for column, value in zip(columns, value_columns)
684        }
685        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
686        null_replacement_columns = [
687            null_replacement_mapping[column.alias_or_name] for column in all_columns
688        ]
689        new_df = new_df.select(*null_replacement_columns)
690        return new_df
691
692    @operation(Operation.FROM)
693    def replace(
694        self,
695        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
696        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
697        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
698    ) -> DataFrame:
699        from sqlglot.dataframe.sql.functions import lit
700
701        old_values = None
702        new_df = self.copy()
703        all_columns = self._get_outer_select_columns(new_df.expression)
704        all_column_mapping = {column.alias_or_name: column for column in all_columns}
705
706        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
707        if isinstance(to_replace, dict):
708            old_values = list(to_replace)
709            new_values = list(to_replace.values())
710        elif not old_values and isinstance(to_replace, list):
711            assert isinstance(value, list), "value must be a list since the replacements are a list"
712            assert len(to_replace) == len(
713                value
714            ), "the replacements and values must be the same length"
715            old_values = to_replace
716            new_values = value
717        else:
718            old_values = [to_replace] * len(columns)
719            new_values = [value] * len(columns)
720        old_values = [lit(value) for value in old_values]
721        new_values = [lit(value) for value in new_values]
722
723        replacement_mapping = {}
724        for column in columns:
725            expression = Column(None)
726            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
727                if i == 0:
728                    expression = F.when(column == old_value, new_value)
729                else:
730                    expression = expression.when(column == old_value, new_value)  # type: ignore
731            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
732                column.expression.alias_or_name
733            )
734
735        replacement_mapping = {**all_column_mapping, **replacement_mapping}
736        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
737        new_df = new_df.select(*replacement_columns)
738        return new_df
739
740    @operation(Operation.SELECT)
741    def withColumn(self, colName: str, col: Column) -> DataFrame:
742        col = self._ensure_and_normalize_col(col)
743        existing_col_names = self.expression.named_selects
744        existing_col_index = (
745            existing_col_names.index(colName) if colName in existing_col_names else None
746        )
747        if existing_col_index is not None:
748            expression = self.expression.copy()
749            expression.expressions[existing_col_index] = col.expression
750            return self.copy(expression=expression)
751        return self.copy().select(col.alias(colName), append=True)
752
753    @operation(Operation.SELECT)
754    def withColumnRenamed(self, existing: str, new: str):
755        expression = self.expression.copy()
756        existing_columns = [
757            expression
758            for expression in expression.expressions
759            if expression.alias_or_name == existing
760        ]
761        if not existing_columns:
762            raise ValueError("Tried to rename a column that doesn't exist")
763        for existing_column in existing_columns:
764            if isinstance(existing_column, exp.Column):
765                existing_column.replace(exp.alias_(existing_column.copy(), new))
766            else:
767                existing_column.set("alias", exp.to_identifier(new))
768        return self.copy(expression=expression)
769
770    @operation(Operation.SELECT)
771    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
772        all_columns = self._get_outer_select_columns(self.expression)
773        drop_cols = self._ensure_and_normalize_cols(cols)
774        new_columns = [
775            col
776            for col in all_columns
777            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
778        ]
779        return self.copy().select(*new_columns, append=False)
780
781    @operation(Operation.LIMIT)
782    def limit(self, num: int) -> DataFrame:
783        return self.copy(expression=self.expression.limit(num))
784
785    @operation(Operation.NO_OP)
786    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
787        parameter_list = ensure_list(parameters)
788        parameter_columns = (
789            self._ensure_list_of_columns(parameter_list)
790            if parameters
791            else Column.ensure_cols([self.sequence_id])
792        )
793        return self._hint(name, parameter_columns)
794
795    @operation(Operation.NO_OP)
796    def repartition(
797        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
798    ) -> DataFrame:
799        num_partition_cols = self._ensure_list_of_columns(numPartitions)
800        columns = self._ensure_and_normalize_cols(cols)
801        args = num_partition_cols + columns
802        return self._hint("repartition", args)
803
804    @operation(Operation.NO_OP)
805    def coalesce(self, numPartitions: int) -> DataFrame:
806        num_partitions = Column.ensure_cols([numPartitions])
807        return self._hint("coalesce", num_partitions)
808
809    @operation(Operation.NO_OP)
810    def cache(self) -> DataFrame:
811        return self._cache(storage_level="MEMORY_AND_DISK")
812
813    @operation(Operation.NO_OP)
814    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
815        """
816        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
817        """
818        return self._cache(storageLevel)
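
Since nothing is executed, cache and persist only mark the backing CTE with a storage level; sql() later expands that marker into DROP VIEW / CACHE TABLE statements ahead of the final SELECT. A sketch, reusing the illustrative employee table:

    df = SparkSession().table("employee").select("fname").cache()
    for statement in df.sql():
        print(statement)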
DataFrame( spark: sqlglot.dataframe.sql.SparkSession, expression: sqlglot.expressions.Select, branch_id: Optional[str] = None, sequence_id: Optional[str] = None, last_op: sqlglot.dataframe.sql.operations.Operation = <Operation.INIT: -1>, pending_hints: Optional[List[sqlglot.expressions.Expression]] = None, output_expression_container: Optional[OutputExpressionContainer] = None, **kwargs)
46    def __init__(
47        self,
48        spark: SparkSession,
49        expression: exp.Select,
50        branch_id: t.Optional[str] = None,
51        sequence_id: t.Optional[str] = None,
52        last_op: Operation = Operation.INIT,
53        pending_hints: t.Optional[t.List[exp.Expression]] = None,
54        output_expression_container: t.Optional[OutputExpressionContainer] = None,
55        **kwargs,
56    ):
57        self.spark = spark
58        self.expression = expression
59        self.branch_id = branch_id or self.spark._random_branch_id
60        self.sequence_id = sequence_id or self.spark._random_sequence_id
61        self.last_op = last_op
62        self.pending_hints = pending_hints or []
63        self.output_expression_container = output_expression_container or exp.Select()
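
DataFrames are normally produced by a SparkSession rather than constructed directly; the session supplies the branch and sequence ids used to track CTE lineage. The wrapped sqlglot expression is always reachable:

    df = SparkSession().createDataFrame([(1,)], ["id"])
    print(type(df.expression))  # <class 'sqlglot.expressions.Select'>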
def sql(self, dialect='spark', optimize=True, **kwargs) -> List[str]:
295    def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
296        df = self._resolve_pending_hints()
297        select_expressions = df._get_select_expressions()
298        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
299        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
300        for expression_type, select_expression in select_expressions:
301            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
302            if optimize:
303                select_expression = optimize_func(select_expression, identify="always")
304            select_expression = df._replace_cte_names_with_hashes(select_expression)
305            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
306            if expression_type == exp.Cache:
307                cache_table_name = df._create_hash_from_expression(select_expression)
308                cache_table = exp.to_table(cache_table_name)
309                original_alias_name = select_expression.args["cte_alias_name"]
310
311                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
312                    cache_table_name
313                )
314                sqlglot.schema.add_table(
315                    cache_table_name,
316                    {
317                        expression.alias_or_name: expression.type.sql("spark")
318                        for expression in select_expression.expressions
319                    },
320                )
321                cache_storage_level = select_expression.args["cache_storage_level"]
322                options = [
323                    exp.Literal.string("storageLevel"),
324                    exp.Literal.string(cache_storage_level),
325                ]
326                expression = exp.Cache(
327                    this=cache_table, expression=select_expression, lazy=True, options=options
328                )
329                # We will drop the "view" if it exists before running the cache table
330                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
331            elif expression_type == exp.Create:
332                expression = df.output_expression_container.copy()
333                expression.set("expression", select_expression)
334            elif expression_type == exp.Insert:
335                expression = df.output_expression_container.copy()
336                select_without_ctes = select_expression.copy()
337                select_without_ctes.set("with", None)
338                expression.set("expression", select_without_ctes)
339                if select_expression.ctes:
340                    expression.set("with", exp.With(expressions=select_expression.ctes))
341            elif expression_type == exp.Select:
342                expression = select_expression
343            else:
344                raise ValueError(f"Invalid expression type: {expression_type}")
345            output_expressions.append(expression)
346
347        return [
348            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
349        ]
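
Because a DataFrame is just a sqlglot expression tree, the same plan can be rendered for other dialects, and the optimizer can be skipped. Given any DataFrame df built as above:

    print(df.sql(dialect="bigquery")[0])
    print(df.sql(optimize=False)[0])  # keep the raw CTE chain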
def copy(self, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
351    def copy(self, **kwargs) -> DataFrame:
352        return DataFrame(**object_to_dict(self, **kwargs))
@operation(Operation.SELECT)
def select(self, *cols, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
354    @operation(Operation.SELECT)
355    def select(self, *cols, **kwargs) -> DataFrame:
356        cols = self._ensure_and_normalize_cols(cols)
357        kwargs["append"] = kwargs.get("append", False)
358        if self.expression.args.get("joins"):
359            ambiguous_cols = [
360                col
361                for col in cols
362                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
363            ]
364            if ambiguous_cols:
365                join_table_identifiers = [
366                    x.this for x in get_tables_from_expression_with_join(self.expression)
367                ]
368                cte_names_in_join = [x.this for x in join_table_identifiers]
369                # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
370                # and therefore we allow multiple columns with the same name in the result. This matches the behavior
371                # of Spark.
372                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
373                for ambiguous_col in ambiguous_cols:
374                    ctes_with_column = [
375                        cte
376                        for cte in self.expression.ctes
377                        if cte.alias_or_name in cte_names_in_join
378                        and ambiguous_col.alias_or_name in cte.this.named_selects
379                    ]
380                    # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
381                    # use the same CTE we used before
382                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
383                    if cte:
384                        resolved_column_position[ambiguous_col] += 1
385                    else:
386                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
387                    ambiguous_col.expression.set("table", cte.alias_or_name)
388        return self.copy(
389            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
390        )
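
A short usage sketch (column names illustrative); as in PySpark, strings and Column expressions can be mixed:

    from sqlglot.dataframe.sql import functions as F

    df = df.select("fname", (F.col("age") + 1).alias("age_plus_one"))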
@operation(Operation.NO_OP)
def alias(self, name: str, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
392    @operation(Operation.NO_OP)
393    def alias(self, name: str, **kwargs) -> DataFrame:
394        new_sequence_id = self.spark._random_sequence_id
395        df = self.copy()
396        for join_hint in df.pending_join_hints:
397            for expression in join_hint.expressions:
398                if expression.alias_or_name == self.sequence_id:
399                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
400        df.spark._add_alias_to_mapping(name, new_sequence_id)
401        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
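
alias is mainly useful for self-joins and for join hints that reference a named DataFrame. A sketch, reusing the illustrative employee table:

    spark = SparkSession()
    e1 = spark.table("employee").alias("e1")
    e2 = spark.table("employee").alias("e2")
    joined = e1.join(e2, on="employee_id", how="left")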
@operation(Operation.WHERE)
def where( self, column: Union[sqlglot.dataframe.sql.Column, bool], **kwargs) -> sqlglot.dataframe.sql.DataFrame:
403    @operation(Operation.WHERE)
404    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
405        col = self._ensure_and_normalize_col(column)
406        return self.copy(expression=self.expression.where(col.expression))
@operation(Operation.WHERE)
def filter( self, column: Union[sqlglot.dataframe.sql.Column, bool], **kwargs) -> sqlglot.dataframe.sql.DataFrame:
403    @operation(Operation.WHERE)
404    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
405        col = self._ensure_and_normalize_col(column)
406        return self.copy(expression=self.expression.where(col.expression))
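
filter is an alias of where; both take a Column condition, and conditions compose with & and |. For example, with F as sqlglot.dataframe.sql.functions:

    df = df.where((F.col("age") >= 18) & (F.col("fname") != F.lit("Jack")))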
@operation(Operation.GROUP_BY)
def groupBy(self, *cols, **kwargs) -> sqlglot.dataframe.sql.GroupedData:
410    @operation(Operation.GROUP_BY)
411    def groupBy(self, *cols, **kwargs) -> GroupedData:
412        columns = self._ensure_and_normalize_cols(cols)
413        return GroupedData(self, columns, self.last_op)
@operation(Operation.SELECT)
def agg(self, *exprs, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
415    @operation(Operation.SELECT)
416    def agg(self, *exprs, **kwargs) -> DataFrame:
417        cols = self._ensure_and_normalize_cols(exprs)
418        return self.groupBy().agg(*cols)
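
agg without an explicit groupBy aggregates over the whole frame, mirroring PySpark. A sketch, again with F as sqlglot.dataframe.sql.functions:

    grouped = df.groupBy(F.col("lname")).agg(F.max("age"))
    totals = df.agg(F.max("age"))  # global aggregation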
@operation(Operation.FROM)
def join( self, other_df: sqlglot.dataframe.sql.DataFrame, on: Union[str, List[str], sqlglot.dataframe.sql.Column, List[sqlglot.dataframe.sql.Column]], how: str = 'inner', **kwargs) -> sqlglot.dataframe.sql.DataFrame:
420    @operation(Operation.FROM)
421    def join(
422        self,
423        other_df: DataFrame,
424        on: t.Union[str, t.List[str], Column, t.List[Column]],
425        how: str = "inner",
426        **kwargs,
427    ) -> DataFrame:
428        other_df = other_df._convert_leaf_to_cte()
429        join_columns = self._ensure_list_of_columns(on)
430        # We will determine actual "join on" expression later so we don't provide it at first
431        join_expression = self.expression.join(
432            other_df.latest_cte_name, join_type=how.replace("_", " ")
433        )
434        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
435        self_columns = self._get_outer_select_columns(join_expression)
436        other_columns = self._get_outer_select_columns(other_df)
437        # Determines the join clause and select columns to be used based on what type of columns were provided for
438        # the join. The columns returned change based on how the on expression is provided.
439        if isinstance(join_columns[0].expression, exp.Column):
440            """
441            Unique characteristics of join on column names only:
442            * The column names are put at the front of the select list
443            * The column names are deduplicated across the entire select list and only the column names (other dups are allowed)
444            """
445            table_names = [
446                table.alias_or_name
447                for table in get_tables_from_expression_with_join(join_expression)
448            ]
449            potential_ctes = [
450                cte
451                for cte in join_expression.ctes
452                if cte.alias_or_name in table_names
453                and cte.alias_or_name != other_df.latest_cte_name
454            ]
455            # Determine the table to reference for the left side of the join by checking each of the left side
456            # tables and see if they have the column being referenced.
457            join_column_pairs = []
458            for join_column in join_columns:
459                num_matching_ctes = 0
460                for cte in potential_ctes:
461                    if join_column.alias_or_name in cte.this.named_selects:
462                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
463                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
464                        join_column_pairs.append((left_column, right_column))
465                        num_matching_ctes += 1
466                if num_matching_ctes > 1:
467                    raise ValueError(
468                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
469                    )
470                elif num_matching_ctes == 0:
471                    raise ValueError(
472                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
473                    )
474            join_clause = functools.reduce(
475                lambda x, y: x & y,
476                [left_column == right_column for left_column, right_column in join_column_pairs],
477            )
478            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
479            # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list
480            select_column_names = [
481                column.alias_or_name
482                if not isinstance(column.expression.this, exp.Star)
483                else column.sql()
484                for column in self_columns + other_columns
485            ]
486            select_column_names = [
487                column_name
488                for column_name in select_column_names
489                if column_name not in join_column_names
490            ]
491            select_column_names = join_column_names + select_column_names
492        else:
493            """
494            Unique characteristics of join on expressions:
495            * There is no deduplication of the results.
496            * The left join dataframe columns go first and right come after. No sort preference is given to join columns
497            """
498            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
499            if len(join_columns) > 1:
500                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
501            join_clause = join_columns[0]
502            select_column_names = [column.alias_or_name for column in self_columns + other_columns]
503
504        # Update the on expression with the actual join clause to replace the dummy one from before
505        join_expression.args["joins"][-1].set("on", join_clause.expression)
506        new_df = self.copy(expression=join_expression)
507        new_df.pending_join_hints.extend(self.pending_join_hints)
508        new_df.pending_hints.extend(other_df.pending_hints)
509        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
510        return new_df
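
A sketch of the two on forms described above (table and column names are illustrative). Joining on column names deduplicates the join columns and moves them to the front of the select list; an ambiguous or missing name raises ValueError:

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("sales", {"store_id": "INT", "amount": "DOUBLE"})
    sqlglot.schema.add_table("store", {"store_id": "INT", "store_name": "STRING"})
    spark = SparkSession()

    sales = spark.table("sales")
    store = spark.table("store")

    # Join on column names: store_id appears once, at the front of the select list
    by_name = sales.join(store, on=["store_id"], how="inner")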
@operation(Operation.ORDER_BY)
def orderBy(self, *cols: Union[str, sqlglot.dataframe.sql.Column], ascending: Union[Any, List[Any], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
512    @operation(Operation.ORDER_BY)
513    def orderBy(
514        self,
515        *cols: t.Union[str, Column],
516        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
517    ) -> DataFrame:
518        """
519        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
520        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
521        is unlikely to come up.
522        """
523        columns = self._ensure_and_normalize_cols(cols)
524        pre_ordered_col_indexes = [
525            x
526            for x in [
527                i if isinstance(col.expression, exp.Ordered) else None
528                for i, col in enumerate(columns)
529            ]
530            if x is not None
531        ]
532        if ascending is None:
533            ascending = [True] * len(columns)
534        elif not isinstance(ascending, list):
535            ascending = [ascending] * len(columns)
536        ascending = [bool(x) for x in ascending]
537        assert len(columns) == len(
538            ascending
539        ), "The length of items in ascending must equal the number of columns provided"
540        col_and_ascending = list(zip(columns, ascending))
541        order_by_columns = [
542            exp.Ordered(this=col.expression, desc=not asc)
543            if i not in pre_ordered_col_indexes
544            else columns[i].column_expression
545            for i, (col, asc) in enumerate(col_and_ascending)
546        ]
547        return self.copy(expression=self.expression.order_by(*order_by_columns))

This implementation lets any pre-ordered columns (for example, Column.desc()) take priority over whatever is provided in ascending. Spark's behavior when the two are mixed is irregular and can result in runtime errors. Users shouldn't mix them anyway, so this is unlikely to come up.
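
A sketch of the two ways to control sort direction (column names are illustrative):

    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql import functions as F

    spark = SparkSession()
    df = spark.createDataFrame([[1, "a"], [2, "b"]], ["id", "name"])

    # ascending may be a single value or a list matching the columns
    df.orderBy("id", "name", ascending=[False, True])
    # Pre-ordered columns take priority over the ascending flag
    df.orderBy(F.col("id").desc(), F.col("name"))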

@operation(Operation.ORDER_BY)
def sort(self, *cols: Union[str, sqlglot.dataframe.sql.Column], ascending: Union[Any, List[Any], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
512    @operation(Operation.ORDER_BY)
513    def orderBy(
514        self,
515        *cols: t.Union[str, Column],
516        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
517    ) -> DataFrame:
518        """
519        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
520        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
521        is unlikely to come up.
522        """
523        columns = self._ensure_and_normalize_cols(cols)
524        pre_ordered_col_indexes = [
525            x
526            for x in [
527                i if isinstance(col.expression, exp.Ordered) else None
528                for i, col in enumerate(columns)
529            ]
530            if x is not None
531        ]
532        if ascending is None:
533            ascending = [True] * len(columns)
534        elif not isinstance(ascending, list):
535            ascending = [ascending] * len(columns)
536        ascending = [bool(x) for x in ascending]
537        assert len(columns) == len(
538            ascending
539        ), "The length of items in ascending must equal the number of columns provided"
540        col_and_ascending = list(zip(columns, ascending))
541        order_by_columns = [
542            exp.Ordered(this=col.expression, desc=not asc)
543            if i not in pre_ordered_col_indexes
544            else columns[i].column_expression
545            for i, (col, asc) in enumerate(col_and_ascending)
546        ]
547        return self.copy(expression=self.expression.order_by(*order_by_columns))

This implementation lets any pre-ordered columns (for example, Column.desc()) take priority over whatever is provided in ascending. Spark's behavior when the two are mixed is irregular and can result in runtime errors. Users shouldn't mix them anyway, so this is unlikely to come up.

@operation(Operation.FROM)
def union(self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
551    @operation(Operation.FROM)
552    def union(self, other: DataFrame) -> DataFrame:
553        return self._set_operation(exp.Union, other, False)
@operation(Operation.FROM)
def unionAll(self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
551    @operation(Operation.FROM)
552    def union(self, other: DataFrame) -> DataFrame:
553        return self._set_operation(exp.Union, other, False)
@operation(Operation.FROM)
def unionByName(self, other: sqlglot.dataframe.sql.DataFrame, allowMissingColumns: bool = False):
557    @operation(Operation.FROM)
558    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
559        l_columns = self.columns
560        r_columns = other.columns
561        if not allowMissingColumns:
562            l_expressions = l_columns
563            r_expressions = l_columns
564        else:
565            l_expressions = []
566            r_expressions = []
567            r_columns_unused = copy(r_columns)
568            for l_column in l_columns:
569                l_expressions.append(l_column)
570                if l_column in r_columns:
571                    r_expressions.append(l_column)
572                    r_columns_unused.remove(l_column)
573                else:
574                    r_expressions.append(exp.alias_(exp.Null(), l_column))
575            for r_column in r_columns_unused:
576                l_expressions.append(exp.alias_(exp.Null(), r_column))
577                r_expressions.append(r_column)
578        r_df = (
579            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
580        )
581        l_df = self.copy()
582        if allowMissingColumns:
583            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
584        return l_df._set_operation(exp.Union, r_df, False)
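
A sketch of both modes (column names are illustrative):

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df1 = spark.createDataFrame([[1, "a"]], ["id", "name"])
    df2 = spark.createDataFrame([["b", 2]], ["name", "id"])

    # Columns are matched by name, not by position
    df = df1.unionByName(df2)

    # With allowMissingColumns, columns absent on one side are filled with NULL
    df3 = spark.createDataFrame([[3]], ["id"])
    df_with_nulls = df1.unionByName(df3, allowMissingColumns=True)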
@operation(Operation.FROM)
def intersect(self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
586    @operation(Operation.FROM)
587    def intersect(self, other: DataFrame) -> DataFrame:
588        return self._set_operation(exp.Intersect, other, True)
@operation(Operation.FROM)
def intersectAll(self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
590    @operation(Operation.FROM)
591    def intersectAll(self, other: DataFrame) -> DataFrame:
592        return self._set_operation(exp.Intersect, other, False)
@operation(Operation.FROM)
def exceptAll(self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
594    @operation(Operation.FROM)
595    def exceptAll(self, other: DataFrame) -> DataFrame:
596        return self._set_operation(exp.Except, other, False)
@operation(Operation.SELECT)
def distinct(self) -> sqlglot.dataframe.sql.DataFrame:
598    @operation(Operation.SELECT)
599    def distinct(self) -> DataFrame:
600        return self.copy(expression=self.expression.distinct())
@operation(Operation.SELECT)
def dropDuplicates(self, subset: Optional[List[str]] = None):
602    @operation(Operation.SELECT)
603    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
604        if not subset:
605            return self.distinct()
606        column_names = ensure_list(subset)
607        window = Window.partitionBy(*column_names).orderBy(*column_names)
608        return (
609            self.copy()
610            .withColumn("row_num", F.row_number().over(window))
611            .where(F.col("row_num") == F.lit(1))
612            .drop("row_num")
613        )
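
A sketch of both code paths (column names are illustrative):

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([[1, "a"], [1, "a"], [1, "b"]], ["id", "name"])

    df.dropDuplicates()        # no subset: plain SELECT DISTINCT
    df.dropDuplicates(["id"])  # subset: ROW_NUMBER() over a window, keeping row_num = 1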
@operation(Operation.FROM)
def dropna(self, how: str = 'any', thresh: Optional[int] = None, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
615    @operation(Operation.FROM)
616    def dropna(
617        self,
618        how: str = "any",
619        thresh: t.Optional[int] = None,
620        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
621    ) -> DataFrame:
622        minimum_non_null = thresh or 0  # will be determined later if thresh is null
623        new_df = self.copy()
624        all_columns = self._get_outer_select_columns(new_df.expression)
625        if subset:
626            null_check_columns = self._ensure_and_normalize_cols(subset)
627        else:
628            null_check_columns = all_columns
629        if thresh is None:
630            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
631        else:
632            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
633        if minimum_num_nulls > len(null_check_columns):
634            raise RuntimeError(
635                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
636                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
637            )
638        if_null_checks = [
639            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
640        ]
641        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
642        num_nulls = nulls_added_together.alias("num_nulls")
643        new_df = new_df.select(num_nulls, append=True)
644        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
645        final_df = filtered_df.select(*all_columns)
646        return final_df
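
A sketch of the how, thresh, and subset parameters (column names are illustrative):

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([[1, "a"], [None, None]], ["id", "name"])

    df.dropna(how="any")                 # drop rows containing at least one NULL
    df.dropna(how="all")                 # drop rows where every column is NULL
    df.dropna(thresh=1, subset=["id"])   # keep rows with at least 1 non-NULL in subset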
@operation(Operation.FROM)
def fillna(self, value: ColumnLiterals, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
648    @operation(Operation.FROM)
649    def fillna(
650        self,
651        value: t.Union[ColumnLiterals],
652        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
653    ) -> DataFrame:
654        """
655        Functionality Difference: If you provide a value to replace a null and that type conflicts
656        with the type of the column then PySpark will just ignore your replacement.
657        This will try to cast them to be the same in some cases. So they won't always match.
658        Best to not mix types so make sure replacement is the same type as the column
659
660        Possibility for improvement: Use `typeof` function to get the type of the column
661        and check if it matches the type of the value provided. If not then make it null.
662        """
663        from sqlglot.dataframe.sql.functions import lit
664
665        values = None
666        columns = None
667        new_df = self.copy()
668        all_columns = self._get_outer_select_columns(new_df.expression)
669        all_column_mapping = {column.alias_or_name: column for column in all_columns}
670        if isinstance(value, dict):
671            values = list(value.values())
672            columns = self._ensure_and_normalize_cols(list(value))
673        if not columns:
674            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
675        if not values:
676            values = [value] * len(columns)
677        value_columns = [lit(value) for value in values]
678
679        null_replacement_mapping = {
680            column.alias_or_name: (
681                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
682            )
683            for column, value in zip(columns, value_columns)
684        }
685        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
686        null_replacement_columns = [
687            null_replacement_mapping[column.alias_or_name] for column in all_columns
688        ]
689        new_df = new_df.select(*null_replacement_columns)
690        return new_df

Functionality difference: if a replacement value's type conflicts with the column's type, PySpark silently ignores the replacement, whereas this implementation will in some cases cast the two to match, so the results won't always agree. Best not to mix types: make sure the replacement is the same type as the column.

Possible improvement: use the typeof function to get the column's type, check whether it matches the type of the provided value, and make the result null if it does not.
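
A sketch of the scalar and dict forms (column names are illustrative):

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([[1, None], [None, "b"]], ["id", "name"])

    df.fillna(0, subset=["id"])              # scalar replacement limited to a subset
    df.fillna({"id": 0, "name": "unknown"})  # dict form: column name -> replacement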

@operation(Operation.FROM)
def replace(self, to_replace: Union[bool, int, float, str, List, Dict], value: Union[bool, int, float, str, List, NoneType] = None, subset: Union[Collection[ColumnOrName], ColumnOrName, NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
692    @operation(Operation.FROM)
693    def replace(
694        self,
695        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
696        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
697        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
698    ) -> DataFrame:
699        from sqlglot.dataframe.sql.functions import lit
700
701        old_values = None
702        new_df = self.copy()
703        all_columns = self._get_outer_select_columns(new_df.expression)
704        all_column_mapping = {column.alias_or_name: column for column in all_columns}
705
706        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
707        if isinstance(to_replace, dict):
708            old_values = list(to_replace)
709            new_values = list(to_replace.values())
710        elif not old_values and isinstance(to_replace, list):
711            assert isinstance(value, list), "value must be a list since the replacements are a list"
712            assert len(to_replace) == len(
713                value
714            ), "the replacements and values must be the same length"
715            old_values = to_replace
716            new_values = value
717        else:
718            old_values = [to_replace] * len(columns)
719            new_values = [value] * len(columns)
720        old_values = [lit(value) for value in old_values]
721        new_values = [lit(value) for value in new_values]
722
723        replacement_mapping = {}
724        for column in columns:
725            expression = Column(None)
726            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
727                if i == 0:
728                    expression = F.when(column == old_value, new_value)
729                else:
730                    expression = expression.when(column == old_value, new_value)  # type: ignore
731            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
732                column.expression.alias_or_name
733            )
734
735        replacement_mapping = {**all_column_mapping, **replacement_mapping}
736        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
737        new_df = new_df.select(*replacement_columns)
738        return new_df
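
A sketch of the three accepted shapes of to_replace (column names are illustrative):

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([[1, 2]], ["a", "b"])

    df.replace(1, 100)                           # scalar -> scalar across all columns
    df.replace([1, 2], [100, 200])               # parallel lists of old and new values
    df.replace({1: 100, 2: 200}, subset=["a"])   # dict of old -> new, limited to a subset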
@operation(Operation.SELECT)
def withColumn(self, colName: str, col: sqlglot.dataframe.sql.Column) -> sqlglot.dataframe.sql.DataFrame:
740    @operation(Operation.SELECT)
741    def withColumn(self, colName: str, col: Column) -> DataFrame:
742        col = self._ensure_and_normalize_col(col)
743        existing_col_names = self.expression.named_selects
744        existing_col_index = (
745            existing_col_names.index(colName) if colName in existing_col_names else None
746        )
747        if existing_col_index is not None:
748            expression = self.expression.copy()
749            expression.expressions[existing_col_index] = col.expression
750            return self.copy(expression=expression)
751        return self.copy().select(col.alias(colName), append=True)
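
A sketch showing both behaviors: a new name appends a column, an existing name replaces it in place (column names are illustrative):

    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql import functions as F

    spark = SparkSession()
    df = spark.createDataFrame([[1]], ["id"])

    df = df.withColumn("id_plus_one", F.col("id") + F.lit(1))  # appends a new column
    df = df.withColumn("id", F.col("id") * F.lit(2))           # replaces the existing column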
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
753    @operation(Operation.SELECT)
754    def withColumnRenamed(self, existing: str, new: str):
755        expression = self.expression.copy()
756        existing_columns = [
757            expression
758            for expression in expression.expressions
759            if expression.alias_or_name == existing
760        ]
761        if not existing_columns:
762            raise ValueError("Tried to rename a column that doesn't exist")
763        for existing_column in existing_columns:
764            if isinstance(existing_column, exp.Column):
765                existing_column.replace(exp.alias_(existing_column.copy(), new))
766            else:
767                existing_column.set("alias", exp.to_identifier(new))
768        return self.copy(expression=expression)
@operation(Operation.SELECT)
def drop(self, *cols: Union[str, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.DataFrame:
770    @operation(Operation.SELECT)
771    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
772        all_columns = self._get_outer_select_columns(self.expression)
773        drop_cols = self._ensure_and_normalize_cols(cols)
774        new_columns = [
775            col
776            for col in all_columns
777            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
778        ]
779        return self.copy().select(*new_columns, append=False)
@operation(Operation.LIMIT)
def limit(self, num: int) -> sqlglot.dataframe.sql.DataFrame:
781    @operation(Operation.LIMIT)
782    def limit(self, num: int) -> DataFrame:
783        return self.copy(expression=self.expression.limit(num))
@operation(Operation.NO_OP)
def hint(self, name: str, *parameters: Union[str, int, NoneType]) -> sqlglot.dataframe.sql.DataFrame:
785    @operation(Operation.NO_OP)
786    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
787        parameter_list = ensure_list(parameters)
788        parameter_columns = (
789            self._ensure_list_of_columns(parameter_list)
790            if parameters
791            else Column.ensure_cols([self.sequence_id])
792        )
793        return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
def repartition(self, numPartitions: Union[int, ColumnOrName], *cols: ColumnOrName) -> sqlglot.dataframe.sql.DataFrame:
795    @operation(Operation.NO_OP)
796    def repartition(
797        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
798    ) -> DataFrame:
799        num_partition_cols = self._ensure_list_of_columns(numPartitions)
800        columns = self._ensure_and_normalize_cols(cols)
801        args = num_partition_cols + columns
802        return self._hint("repartition", args)
@operation(Operation.NO_OP)
def coalesce(self, numPartitions: int) -> sqlglot.dataframe.sql.DataFrame:
804    @operation(Operation.NO_OP)
805    def coalesce(self, numPartitions: int) -> DataFrame:
806        num_partitions = Column.ensure_cols([numPartitions])
807        return self._hint("coalesce", num_partitions)
@operation(Operation.NO_OP)
def cache(self) -> sqlglot.dataframe.sql.DataFrame:
809    @operation(Operation.NO_OP)
810    def cache(self) -> DataFrame:
811        return self._cache(storage_level="MEMORY_AND_DISK")
@operation(Operation.NO_OP)
def persist(self, storageLevel: str = 'MEMORY_AND_DISK_SER') -> sqlglot.dataframe.sql.DataFrame:
813    @operation(Operation.NO_OP)
814    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
815        """
816        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
817        """
818        return self._cache(storageLevel)
class GroupedData:
14class GroupedData:
15    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
16        self._df = df.copy()
17        self.spark = df.spark
18        self.last_op = last_op
19        self.group_by_cols = group_by_cols
20
21    def _get_function_applied_columns(
22        self, func_name: str, cols: t.Tuple[str, ...]
23    ) -> t.List[Column]:
24        func_name = func_name.lower()
25        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
26
27    @operation(Operation.SELECT)
28    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
29        columns = (
30            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
31            if isinstance(exprs[0], dict)
32            else exprs
33        )
34        cols = self._df._ensure_and_normalize_cols(columns)
35
36        expression = self._df.expression.group_by(
37            *[x.expression for x in self.group_by_cols]
38        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
39        return self._df.copy(expression=expression)
40
41    def count(self) -> DataFrame:
42        return self.agg(F.count("*").alias("count"))
43
44    def mean(self, *cols: str) -> DataFrame:
45        return self.avg(*cols)
46
47    def avg(self, *cols: str) -> DataFrame:
48        return self.agg(*self._get_function_applied_columns("avg", cols))
49
50    def max(self, *cols: str) -> DataFrame:
51        return self.agg(*self._get_function_applied_columns("max", cols))
52
53    def min(self, *cols: str) -> DataFrame:
54        return self.agg(*self._get_function_applied_columns("min", cols))
55
56    def sum(self, *cols: str) -> DataFrame:
57        return self.agg(*self._get_function_applied_columns("sum", cols))
58
59    def pivot(self, *cols: str) -> DataFrame:
60        raise NotImplementedError("Pivot is not currently implemented")
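
A sketch of the aggregation entry points (column names are illustrative):

    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql import functions as F

    spark = SparkSession()
    df = spark.createDataFrame([[1, 30], [2, 30]], ["employee_id", "age"])
    grouped = df.groupBy(F.col("age"))

    grouped.agg({"employee_id": "count"})                   # dict form: column -> aggregate name
    grouped.agg(F.countDistinct("employee_id").alias("n"))  # Column expression form
    grouped.count()                                         # shorthand for COUNT(*)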
GroupedData(df: sqlglot.dataframe.sql.DataFrame, group_by_cols: List[sqlglot.dataframe.sql.Column], last_op: sqlglot.dataframe.sql.operations.Operation)
15    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
16        self._df = df.copy()
17        self.spark = df.spark
18        self.last_op = last_op
19        self.group_by_cols = group_by_cols
@operation(Operation.SELECT)
def agg(self, *exprs: Union[sqlglot.dataframe.sql.Column, Dict[str, str]]) -> sqlglot.dataframe.sql.DataFrame:
27    @operation(Operation.SELECT)
28    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
29        columns = (
30            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
31            if isinstance(exprs[0], dict)
32            else exprs
33        )
34        cols = self._df._ensure_and_normalize_cols(columns)
35
36        expression = self._df.expression.group_by(
37            *[x.expression for x in self.group_by_cols]
38        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
39        return self._df.copy(expression=expression)
def count(self) -> sqlglot.dataframe.sql.DataFrame:
41    def count(self) -> DataFrame:
42        return self.agg(F.count("*").alias("count"))
def mean(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
44    def mean(self, *cols: str) -> DataFrame:
45        return self.avg(*cols)
def avg(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
47    def avg(self, *cols: str) -> DataFrame:
48        return self.agg(*self._get_function_applied_columns("avg", cols))
def max(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
50    def max(self, *cols: str) -> DataFrame:
51        return self.agg(*self._get_function_applied_columns("max", cols))
def min(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
53    def min(self, *cols: str) -> DataFrame:
54        return self.agg(*self._get_function_applied_columns("min", cols))
def sum(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
56    def sum(self, *cols: str) -> DataFrame:
57        return self.agg(*self._get_function_applied_columns("sum", cols))
def pivot(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
59    def pivot(self, *cols: str) -> DataFrame:
60        raise NotImplementedError("Pivot is not currently implemented")
class Column:
 16class Column:
 17    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
 18        if isinstance(expression, Column):
 19            expression = expression.expression  # type: ignore
 20        elif expression is None or not isinstance(expression, (str, exp.Expression)):
 21            expression = self._lit(expression).expression  # type: ignore
 22
 23        parsed = sqlglot.maybe_parse(expression, dialect="spark")
 24        if parsed is None:
 25            raise ValueError(f"Could not parse {expression}")
 26        self.expression: exp.Expression = parsed
 27
 28    def __repr__(self):
 29        return repr(self.expression)
 30
 31    def __hash__(self):
 32        return hash(self.expression)
 33
 34    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
 35        return self.binary_op(exp.EQ, other)
 36
 37    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
 38        return self.binary_op(exp.NEQ, other)
 39
 40    def __gt__(self, other: ColumnOrLiteral) -> Column:
 41        return self.binary_op(exp.GT, other)
 42
 43    def __ge__(self, other: ColumnOrLiteral) -> Column:
 44        return self.binary_op(exp.GTE, other)
 45
 46    def __lt__(self, other: ColumnOrLiteral) -> Column:
 47        return self.binary_op(exp.LT, other)
 48
 49    def __le__(self, other: ColumnOrLiteral) -> Column:
 50        return self.binary_op(exp.LTE, other)
 51
 52    def __and__(self, other: ColumnOrLiteral) -> Column:
 53        return self.binary_op(exp.And, other)
 54
 55    def __or__(self, other: ColumnOrLiteral) -> Column:
 56        return self.binary_op(exp.Or, other)
 57
 58    def __mod__(self, other: ColumnOrLiteral) -> Column:
 59        return self.binary_op(exp.Mod, other)
 60
 61    def __add__(self, other: ColumnOrLiteral) -> Column:
 62        return self.binary_op(exp.Add, other)
 63
 64    def __sub__(self, other: ColumnOrLiteral) -> Column:
 65        return self.binary_op(exp.Sub, other)
 66
 67    def __mul__(self, other: ColumnOrLiteral) -> Column:
 68        return self.binary_op(exp.Mul, other)
 69
 70    def __truediv__(self, other: ColumnOrLiteral) -> Column:
 71        return self.binary_op(exp.Div, other)
 72
 73    def __div__(self, other: ColumnOrLiteral) -> Column:
 74        return self.binary_op(exp.Div, other)
 75
 76    def __neg__(self) -> Column:
 77        return self.unary_op(exp.Neg)
 78
 79    def __radd__(self, other: ColumnOrLiteral) -> Column:
 80        return self.inverse_binary_op(exp.Add, other)
 81
 82    def __rsub__(self, other: ColumnOrLiteral) -> Column:
 83        return self.inverse_binary_op(exp.Sub, other)
 84
 85    def __rmul__(self, other: ColumnOrLiteral) -> Column:
 86        return self.inverse_binary_op(exp.Mul, other)
 87
 88    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
 89        return self.inverse_binary_op(exp.Div, other)
 90
 91    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
 92        return self.inverse_binary_op(exp.Div, other)
 93
 94    def __rmod__(self, other: ColumnOrLiteral) -> Column:
 95        return self.inverse_binary_op(exp.Mod, other)
 96
 97    def __pow__(self, power: ColumnOrLiteral, modulo=None):
 98        return Column(exp.Pow(this=self.expression, expression=Column(power).expression))
 99
100    def __rpow__(self, power: ColumnOrLiteral):
101        return Column(exp.Pow(this=Column(power).expression, expression=self.expression))
102
103    def __invert__(self):
104        return self.unary_op(exp.Not)
105
106    def __rand__(self, other: ColumnOrLiteral) -> Column:
107        return self.inverse_binary_op(exp.And, other)
108
109    def __ror__(self, other: ColumnOrLiteral) -> Column:
110        return self.inverse_binary_op(exp.Or, other)
111
112    @classmethod
113    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
114        return cls(value)
115
116    @classmethod
117    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
118        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
119
120    @classmethod
121    def _lit(cls, value: ColumnOrLiteral) -> Column:
122        if isinstance(value, dict):
123            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
124            return cls(exp.Struct(expressions=columns))
125        return cls(exp.convert(value))
126
127    @classmethod
128    def invoke_anonymous_function(
129        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
130    ) -> Column:
131        columns = [] if column is None else [cls.ensure_col(column)]
132        column_args = [cls.ensure_col(arg) for arg in args]
133        expressions = [x.expression for x in columns + column_args]
134        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
135        return Column(new_expression)
136
137    @classmethod
138    def invoke_expression_over_column(
139        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
140    ) -> Column:
141        ensured_column = None if column is None else cls.ensure_col(column)
142        ensure_expression_values = {
143            k: [Column.ensure_col(x).expression for x in v]
144            if is_iterable(v)
145            else Column.ensure_col(v).expression
146            for k, v in kwargs.items()
147            if v is not None
148        }
149        new_expression = (
150            callable_expression(**ensure_expression_values)
151            if ensured_column is None
152            else callable_expression(
153                this=ensured_column.column_expression, **ensure_expression_values
154            )
155        )
156        return Column(new_expression)
157
158    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
159        return Column(
160            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
161        )
162
163    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
164        return Column(
165            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
166        )
167
168    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
169        return Column(klass(this=self.column_expression, **kwargs))
170
171    @property
172    def is_alias(self):
173        return isinstance(self.expression, exp.Alias)
174
175    @property
176    def is_column(self):
177        return isinstance(self.expression, exp.Column)
178
179    @property
180    def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
181        return self.expression.unalias()
182
183    @property
184    def alias_or_name(self) -> str:
185        return self.expression.alias_or_name
186
187    @classmethod
188    def ensure_literal(cls, value) -> Column:
189        from sqlglot.dataframe.sql.functions import lit
190
191        if isinstance(value, cls):
192            value = value.expression
193        if not isinstance(value, exp.Literal):
194            return lit(value)
195        return Column(value)
196
197    def copy(self) -> Column:
198        return Column(self.expression.copy())
199
200    def set_table_name(self, table_name: str, copy=False) -> Column:
201        expression = self.expression.copy() if copy else self.expression
202        expression.set("table", exp.to_identifier(table_name))
203        return Column(expression)
204
205    def sql(self, **kwargs) -> str:
206        return self.expression.sql(**{"dialect": "spark", **kwargs})
207
208    def alias(self, name: str) -> Column:
209        new_expression = exp.alias_(self.column_expression, name)
210        return Column(new_expression)
211
212    def asc(self) -> Column:
213        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
214        return Column(new_expression)
215
216    def desc(self) -> Column:
217        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
218        return Column(new_expression)
219
220    asc_nulls_first = asc
221
222    def asc_nulls_last(self) -> Column:
223        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
224        return Column(new_expression)
225
226    def desc_nulls_first(self) -> Column:
227        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
228        return Column(new_expression)
229
230    desc_nulls_last = desc
231
232    def when(self, condition: Column, value: t.Any) -> Column:
233        from sqlglot.dataframe.sql.functions import when
234
235        column_with_if = when(condition, value)
236        if not isinstance(self.expression, exp.Case):
237            return column_with_if
238        new_column = self.copy()
239        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
240        return new_column
241
242    def otherwise(self, value: t.Any) -> Column:
243        from sqlglot.dataframe.sql.functions import lit
244
245        true_value = value if isinstance(value, Column) else lit(value)
246        new_column = self.copy()
247        new_column.expression.set("default", true_value.column_expression)
248        return new_column
249
250    def isNull(self) -> Column:
251        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
252        return Column(new_expression)
253
254    def isNotNull(self) -> Column:
255        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
256        return Column(new_expression)
257
258    def cast(self, dataType: t.Union[str, DataType]):
259        """
260        Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
261        Sqlglot doesn't currently replicate this class so it only accepts a string
262        """
263        if isinstance(dataType, DataType):
264            dataType = dataType.simpleString()
265        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))
266
267    def startswith(self, value: t.Union[str, Column]) -> Column:
268        value = self._lit(value) if not isinstance(value, Column) else value
269        return self.invoke_anonymous_function(self, "STARTSWITH", value)
270
271    def endswith(self, value: t.Union[str, Column]) -> Column:
272        value = self._lit(value) if not isinstance(value, Column) else value
273        return self.invoke_anonymous_function(self, "ENDSWITH", value)
274
275    def rlike(self, regexp: str) -> Column:
276        return self.invoke_expression_over_column(
277            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
278        )
279
280    def like(self, other: str):
281        return self.invoke_expression_over_column(
282            self, exp.Like, expression=self._lit(other).expression
283        )
284
285    def ilike(self, other: str):
286        return self.invoke_expression_over_column(
287            self, exp.ILike, expression=self._lit(other).expression
288        )
289
290    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
291        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
292        length = self._lit(length) if not isinstance(length, Column) else length
293        return Column.invoke_expression_over_column(
294            self, exp.Substring, start=startPos.expression, length=length.expression
295        )
296
297    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
298        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
299        expressions = [self._lit(x).expression for x in columns]
300        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
301
302    def between(
303        self,
304        lowerBound: t.Union[ColumnOrLiteral],
305        upperBound: t.Union[ColumnOrLiteral],
306    ) -> Column:
307        lower_bound_exp = (
308            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
309        )
310        upper_bound_exp = (
311            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
312        )
313        return Column(
314            exp.Between(
315                this=self.column_expression,
316                low=lower_bound_exp.expression,
317                high=upper_bound_exp.expression,
318            )
319        )
320
321    def over(self, window: WindowSpec) -> Column:
322        window_expression = window.expression.copy()
323        window_expression.set("this", self.column_expression)
324        return Column(window_expression)
Column(expression: Union[ColumnOrLiteral, sqlglot.expressions.Expression, NoneType])
17    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
18        if isinstance(expression, Column):
19            expression = expression.expression  # type: ignore
20        elif expression is None or not isinstance(expression, (str, exp.Expression)):
21            expression = self._lit(expression).expression  # type: ignore
22
23        parsed = sqlglot.maybe_parse(expression, dialect="spark")
24        if parsed is None:
25            raise ValueError(f"Could not parse {expression}")
26        self.expression: exp.Expression = parsed
@classmethod
def ensure_col(cls, value: Union[ColumnOrLiteral, sqlglot.expressions.Expression, NoneType]):
112    @classmethod
113    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
114        return cls(value)
@classmethod
def ensure_cols(cls, args: List[Union[ColumnOrLiteral, sqlglot.expressions.Expression]]) -> List[sqlglot.dataframe.sql.Column]:
116    @classmethod
117    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
118        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
@classmethod
def invoke_anonymous_function(cls, column: Optional[ColumnOrLiteral], func_name: str, *args: Optional[ColumnOrLiteral]) -> sqlglot.dataframe.sql.Column:
127    @classmethod
128    def invoke_anonymous_function(
129        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
130    ) -> Column:
131        columns = [] if column is None else [cls.ensure_col(column)]
132        column_args = [cls.ensure_col(arg) for arg in args]
133        expressions = [x.expression for x in columns + column_args]
134        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
135        return Column(new_expression)
@classmethod
def invoke_expression_over_column(cls, column: Optional[ColumnOrLiteral], callable_expression: Callable, **kwargs) -> sqlglot.dataframe.sql.Column:
137    @classmethod
138    def invoke_expression_over_column(
139        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
140    ) -> Column:
141        ensured_column = None if column is None else cls.ensure_col(column)
142        ensure_expression_values = {
143            k: [Column.ensure_col(x).expression for x in v]
144            if is_iterable(v)
145            else Column.ensure_col(v).expression
146            for k, v in kwargs.items()
147            if v is not None
148        }
149        new_expression = (
150            callable_expression(**ensure_expression_values)
151            if ensured_column is None
152            else callable_expression(
153                this=ensured_column.column_expression, **ensure_expression_values
154            )
155        )
156        return Column(new_expression)
def binary_op(self, klass: Callable, other: ColumnOrLiteral, **kwargs) -> sqlglot.dataframe.sql.Column:
158    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
159        return Column(
160            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
161        )
def inverse_binary_op(self, klass: Callable, other: ColumnOrLiteral, **kwargs) -> sqlglot.dataframe.sql.Column:
163    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
164        return Column(
165            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
166        )
def unary_op(self, klass: Callable, **kwargs) -> sqlglot.dataframe.sql.Column:
168    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
169        return Column(klass(this=self.column_expression, **kwargs))
@classmethod
def ensure_literal(cls, value) -> sqlglot.dataframe.sql.Column:
187    @classmethod
188    def ensure_literal(cls, value) -> Column:
189        from sqlglot.dataframe.sql.functions import lit
190
191        if isinstance(value, cls):
192            value = value.expression
193        if not isinstance(value, exp.Literal):
194            return lit(value)
195        return Column(value)
def copy(self) -> sqlglot.dataframe.sql.Column:
197    def copy(self) -> Column:
198        return Column(self.expression.copy())
def set_table_name(self, table_name: str, copy=False) -> sqlglot.dataframe.sql.Column:
200    def set_table_name(self, table_name: str, copy=False) -> Column:
201        expression = self.expression.copy() if copy else self.expression
202        expression.set("table", exp.to_identifier(table_name))
203        return Column(expression)
def sql(self, **kwargs) -> str:
205    def sql(self, **kwargs) -> str:
206        return self.expression.sql(**{"dialect": "spark", **kwargs})
def alias(self, name: str) -> sqlglot.dataframe.sql.Column:
208    def alias(self, name: str) -> Column:
209        new_expression = exp.alias_(self.column_expression, name)
210        return Column(new_expression)
def asc(self) -> sqlglot.dataframe.sql.Column:
212    def asc(self) -> Column:
213        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
214        return Column(new_expression)
def desc(self) -> sqlglot.dataframe.sql.Column:
216    def desc(self) -> Column:
217        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
218        return Column(new_expression)
def asc_nulls_first(self) -> sqlglot.dataframe.sql.Column:
212    def asc(self) -> Column:
213        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
214        return Column(new_expression)
def asc_nulls_last(self) -> sqlglot.dataframe.sql.Column:
222    def asc_nulls_last(self) -> Column:
223        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
224        return Column(new_expression)
def desc_nulls_first(self) -> sqlglot.dataframe.sql.Column:
226    def desc_nulls_first(self) -> Column:
227        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
228        return Column(new_expression)
def desc_nulls_last(self) -> sqlglot.dataframe.sql.Column:
216    def desc(self) -> Column:
217        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
218        return Column(new_expression)
def when(self, condition: sqlglot.dataframe.sql.Column, value: Any) -> sqlglot.dataframe.sql.Column:
232    def when(self, condition: Column, value: t.Any) -> Column:
233        from sqlglot.dataframe.sql.functions import when
234
235        column_with_if = when(condition, value)
236        if not isinstance(self.expression, exp.Case):
237            return column_with_if
238        new_column = self.copy()
239        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
240        return new_column
def otherwise(self, value: Any) -> sqlglot.dataframe.sql.Column:
242    def otherwise(self, value: t.Any) -> Column:
243        from sqlglot.dataframe.sql.functions import lit
244
245        true_value = value if isinstance(value, Column) else lit(value)
246        new_column = self.copy()
247        new_column.expression.set("default", true_value.column_expression)
248        return new_column
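
A sketch of how chained WHENs build up a single CASE expression, with otherwise supplying the ELSE branch (column names are illustrative):

    from sqlglot.dataframe.sql import functions as F

    age_band = (
        F.when(F.col("age") < F.lit(13), F.lit("child"))
        .when(F.col("age") < F.lit(18), F.lit("teen"))  # extends the same CASE
        .otherwise(F.lit("adult"))
        .alias("age_band")
    )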
def isNull(self) -> sqlglot.dataframe.sql.Column:
250    def isNull(self) -> Column:
251        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
252        return Column(new_expression)
def isNotNull(self) -> sqlglot.dataframe.sql.Column:
254    def isNotNull(self) -> Column:
255        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
256        return Column(new_expression)
def cast(self, dataType: Union[str, sqlglot.dataframe.sql.types.DataType]):
258    def cast(self, dataType: t.Union[str, DataType]):
259        """
260        Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
261        Sqlglot doesn't currently replicate this class so it only accepts a string
262        """
263        if isinstance(dataType, DataType):
264            dataType = dataType.simpleString()
265        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))

Functionality difference: PySpark's cast accepts an instance of the DataType class. sqlglot doesn't fully replicate that class, so a DataType instance is converted to its simpleString() representation; otherwise a string type name is expected.
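
A sketch of both accepted argument types (the column name is illustrative):

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql import types

    F.col("age").cast("string")            # string type name
    F.col("age").cast(types.StringType())  # DataType instance, converted via simpleString()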

def startswith(self, value: Union[str, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.Column:
267    def startswith(self, value: t.Union[str, Column]) -> Column:
268        value = self._lit(value) if not isinstance(value, Column) else value
269        return self.invoke_anonymous_function(self, "STARTSWITH", value)
def endswith(self, value: Union[str, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.Column:
271    def endswith(self, value: t.Union[str, Column]) -> Column:
272        value = self._lit(value) if not isinstance(value, Column) else value
273        return self.invoke_anonymous_function(self, "ENDSWITH", value)
def rlike(self, regexp: str) -> sqlglot.dataframe.sql.Column:
275    def rlike(self, regexp: str) -> Column:
276        return self.invoke_expression_over_column(
277            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
278        )
def like(self, other: str):
280    def like(self, other: str):
281        return self.invoke_expression_over_column(
282            self, exp.Like, expression=self._lit(other).expression
283        )
def ilike(self, other: str):
285    def ilike(self, other: str):
286        return self.invoke_expression_over_column(
287            self, exp.ILike, expression=self._lit(other).expression
288        )
def substr(self, startPos: Union[int, sqlglot.dataframe.sql.Column], length: Union[int, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.Column:
290    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
291        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
292        length = self._lit(length) if not isinstance(length, Column) else length
293        return Column.invoke_expression_over_column(
294            self, exp.Substring, start=startPos.expression, length=length.expression
295        )
def isin(self, *cols: Union[ColumnOrLiteral, Iterable[ColumnOrLiteral]]):
297    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
298        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
299        expressions = [self._lit(x).expression for x in columns]
300        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
def between(self, lowerBound: ColumnOrLiteral, upperBound: ColumnOrLiteral) -> sqlglot.dataframe.sql.Column:
302    def between(
303        self,
304        lowerBound: t.Union[ColumnOrLiteral],
305        upperBound: t.Union[ColumnOrLiteral],
306    ) -> Column:
307        lower_bound_exp = (
308            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
309        )
310        upper_bound_exp = (
311            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
312        )
313        return Column(
314            exp.Between(
315                this=self.column_expression,
316                low=lower_bound_exp.expression,
317                high=upper_bound_exp.expression,
318            )
319        )
def over(self, window: WindowSpec) -> sqlglot.dataframe.sql.Column:
321    def over(self, window: WindowSpec) -> Column:
322        window_expression = window.expression.copy()
323        window_expression.set("this", self.column_expression)
324        return Column(window_expression)
class DataFrameNaFunctions:
821class DataFrameNaFunctions:
822    def __init__(self, df: DataFrame):
823        self.df = df
824
825    def drop(
826        self,
827        how: str = "any",
828        thresh: t.Optional[int] = None,
829        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
830    ) -> DataFrame:
831        return self.df.dropna(how=how, thresh=thresh, subset=subset)
832
833    def fill(
834        self,
835        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
836        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
837    ) -> DataFrame:
838        return self.df.fillna(value=value, subset=subset)
839
840    def replace(
841        self,
842        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
843        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
844        subset: t.Optional[t.Union[str, t.List[str]]] = None,
845    ) -> DataFrame:
846        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
DataFrameNaFunctions(df: sqlglot.dataframe.sql.DataFrame)
822    def __init__(self, df: DataFrame):
823        self.df = df
def drop(self, how: str = 'any', thresh: Optional[int] = None, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
825    def drop(
826        self,
827        how: str = "any",
828        thresh: t.Optional[int] = None,
829        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
830    ) -> DataFrame:
831        return self.df.dropna(how=how, thresh=thresh, subset=subset)
def fill(self, value: Union[int, bool, float, str, Dict[str, Any]], subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
833    def fill(
834        self,
835        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
836        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
837    ) -> DataFrame:
838        return self.df.fillna(value=value, subset=subset)
def replace(self, to_replace: Union[bool, int, float, str, List, Dict], value: Union[bool, int, float, str, List, NoneType] = None, subset: Union[str, List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
840    def replace(
841        self,
842        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
843        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
844        subset: t.Optional[t.Union[str, t.List[str]]] = None,
845    ) -> DataFrame:
846        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
class Window:
15class Window:
16    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
17    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
18    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
19    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)
20
21    unboundedPreceding: int = _JAVA_MIN_LONG
22
23    unboundedFollowing: int = _JAVA_MAX_LONG
24
25    currentRow: int = 0
26
27    @classmethod
28    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
29        return WindowSpec().partitionBy(*cols)
30
31    @classmethod
32    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
33        return WindowSpec().orderBy(*cols)
34
35    @classmethod
36    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
37        return WindowSpec().rowsBetween(start, end)
38
39    @classmethod
40    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
41        return WindowSpec().rangeBetween(start, end)
@classmethod
def partitionBy( cls, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
27    @classmethod
28    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
29        return WindowSpec().partitionBy(*cols)
@classmethod
def orderBy( cls, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
31    @classmethod
32    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
33        return WindowSpec().orderBy(*cols)
@classmethod
def rowsBetween(cls, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
35    @classmethod
36    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
37        return WindowSpec().rowsBetween(start, end)
@classmethod
def rangeBetween(cls, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
39    @classmethod
40    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
41        return WindowSpec().rangeBetween(start, end)
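The classmethods above are conveniences that construct a fresh WindowSpec, with frame bounds given as plain integers; unboundedPreceding and unboundedFollowing act as sentinel values clamped to the Java long range. A sketch of a running-total frame (column names are illustrative):

from sqlglot.dataframe.sql.window import Window
from sqlglot.dataframe.sql import functions as F

# Everything from the start of the partition up to the current row
running = Window.partitionBy("store_id").orderBy("age").rowsBetween(
    Window.unboundedPreceding, Window.currentRow
)
running_total = F.sum("age").over(running).alias("running_age")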
class WindowSpec:
 44class WindowSpec:
 45    def __init__(self, expression: exp.Expression = exp.Window()):
 46        self.expression = expression
 47
 48    def copy(self):
 49        return WindowSpec(self.expression.copy())
 50
 51    def sql(self, **kwargs) -> str:
 52        return self.expression.sql(dialect="spark", **kwargs)
 53
 54    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
 55        from sqlglot.dataframe.sql.column import Column
 56
 57        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
 58        expressions = [Column.ensure_col(x).expression for x in cols]
 59        window_spec = self.copy()
 60        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
 61        partition_by_expressions.extend(expressions)
 62        window_spec.expression.set("partition_by", partition_by_expressions)
 63        return window_spec
 64
 65    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
 66        from sqlglot.dataframe.sql.column import Column
 67
 68        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
 69        expressions = [Column.ensure_col(x).expression for x in cols]
 70        window_spec = self.copy()
 71        if window_spec.expression.args.get("order") is None:
 72            window_spec.expression.set("order", exp.Order(expressions=[]))
 73        order_by = window_spec.expression.args["order"].expressions
 74        order_by.extend(expressions)
 75        window_spec.expression.args["order"].set("expressions", order_by)
 76        return window_spec
 77
 78    def _calc_start_end(
 79        self, start: int, end: int
 80    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
 81        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
 82            "start_side": None,
 83            "end_side": None,
 84        }
 85        if start == Window.currentRow:
 86            kwargs["start"] = "CURRENT ROW"
 87        else:
 88            kwargs = {
 89                **kwargs,
 90                **{
 91                    "start_side": "PRECEDING",
 92                    "start": "UNBOUNDED"
 93                    if start <= Window.unboundedPreceding
 94                    else F.lit(start).expression,
 95                },
 96            }
 97        if end == Window.currentRow:
 98            kwargs["end"] = "CURRENT ROW"
 99        else:
100            kwargs = {
101                **kwargs,
102                **{
103                    "end_side": "FOLLOWING",
104                    "end": "UNBOUNDED"
105                    if end >= Window.unboundedFollowing
106                    else F.lit(end).expression,
107                },
108            }
109        return kwargs
110
111    def rowsBetween(self, start: int, end: int) -> WindowSpec:
112        window_spec = self.copy()
113        spec = self._calc_start_end(start, end)
114        spec["kind"] = "ROWS"
115        window_spec.expression.set(
116            "spec",
117            exp.WindowSpec(
118                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
119            ),
120        )
121        return window_spec
122
123    def rangeBetween(self, start: int, end: int) -> WindowSpec:
124        window_spec = self.copy()
125        spec = self._calc_start_end(start, end)
126        spec["kind"] = "RANGE"
127        window_spec.expression.set(
128            "spec",
129            exp.WindowSpec(
130                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
131            ),
132        )
133        return window_spec
WindowSpec(expression: sqlglot.expressions.Expression = exp.Window())
45    def __init__(self, expression: exp.Expression = exp.Window()):
46        self.expression = expression
def copy(self):
48    def copy(self):
49        return WindowSpec(self.expression.copy())
def sql(self, **kwargs) -> str:
51    def sql(self, **kwargs) -> str:
52        return self.expression.sql(dialect="spark", **kwargs)
def partitionBy( self, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
54    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
55        from sqlglot.dataframe.sql.column import Column
56
57        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
58        expressions = [Column.ensure_col(x).expression for x in cols]
59        window_spec = self.copy()
60        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
61        partition_by_expressions.extend(expressions)
62        window_spec.expression.set("partition_by", partition_by_expressions)
63        return window_spec
def orderBy( self, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
65    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
66        from sqlglot.dataframe.sql.column import Column
67
68        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
69        expressions = [Column.ensure_col(x).expression for x in cols]
70        window_spec = self.copy()
71        if window_spec.expression.args.get("order") is None:
72            window_spec.expression.set("order", exp.Order(expressions=[]))
73        order_by = window_spec.expression.args["order"].expressions
74        order_by.extend(expressions)
75        window_spec.expression.args["order"].set("expressions", order_by)
76        return window_spec
def rowsBetween(self, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
111    def rowsBetween(self, start: int, end: int) -> WindowSpec:
112        window_spec = self.copy()
113        spec = self._calc_start_end(start, end)
114        spec["kind"] = "ROWS"
115        window_spec.expression.set(
116            "spec",
117            exp.WindowSpec(
118                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
119            ),
120        )
121        return window_spec
def rangeBetween(self, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
123    def rangeBetween(self, start: int, end: int) -> WindowSpec:
124        window_spec = self.copy()
125        spec = self._calc_start_end(start, end)
126        spec["kind"] = "RANGE"
127        window_spec.expression.set(
128            "spec",
129            exp.WindowSpec(
130                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
131            ),
132        )
133        return window_spec
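Every builder method operates on a copy (see copy() above), so a base spec can be shared and extended without mutation. A sketch; the rendered SQL shown in the comment is approximate:

from sqlglot.dataframe.sql.window import Window, WindowSpec

base = WindowSpec().partitionBy("store_id")  # returns a new WindowSpec
by_age = base.orderBy("age")                 # base itself is left unchanged
framed = by_age.rowsBetween(Window.currentRow, Window.unboundedFollowing)

# Roughly: OVER (PARTITION BY store_id ORDER BY age
#                ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
print(framed.sql())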
class DataFrameReader:
15class DataFrameReader:
16    def __init__(self, spark: SparkSession):
17        self.spark = spark
18
19    def table(self, tableName: str) -> DataFrame:
20        from sqlglot.dataframe.sql.dataframe import DataFrame
21
22        sqlglot.schema.add_table(tableName)
23
24        return DataFrame(
25            self.spark,
26            exp.Select()
27            .from_(tableName)
28            .select(
29                *(
30                    column if should_identify(column, "safe") else f'"{column}"'
31                    for column in sqlglot.schema.column_names(tableName)
32                )
33            ),
34        )
DataFrameReader(spark: sqlglot.dataframe.sql.SparkSession)
16    def __init__(self, spark: SparkSession):
17        self.spark = spark
def table(self, tableName: str) -> sqlglot.dataframe.sql.DataFrame:
19    def table(self, tableName: str) -> DataFrame:
20        from sqlglot.dataframe.sql.dataframe import DataFrame
21
22        sqlglot.schema.add_table(tableName)
23
24        return DataFrame(
25            self.spark,
26            exp.Select()
27            .from_(tableName)
28            .select(
29                *(
30                    column if should_identify(column, "safe") else f'"{column}"'
31                    for column in sqlglot.schema.column_names(tableName)
32                )
33            ),
34        )
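table() registers the name with sqlglot's global schema and builds a SELECT over the columns that schema already knows about, so tables you plan to read should be added up front. For example:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"id": "INT", "name": "STRING"})
spark = SparkSession()

df = spark.read.table("employee")  # equivalent to spark.table("employee")
print(df.sql()[0])                 # a SELECT over the schema's known columns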
class DataFrameWriter:
37class DataFrameWriter:
38    def __init__(
39        self,
40        df: DataFrame,
41        spark: t.Optional[SparkSession] = None,
42        mode: t.Optional[str] = None,
43        by_name: bool = False,
44    ):
45        self._df = df
46        self._spark = spark or df.spark
47        self._mode = mode
48        self._by_name = by_name
49
50    def copy(self, **kwargs) -> DataFrameWriter:
51        return DataFrameWriter(
52            **{
53                k[1:] if k.startswith("_") else k: v
54                for k, v in object_to_dict(self, **kwargs).items()
55            }
56        )
57
58    def sql(self, **kwargs) -> t.List[str]:
59        return self._df.sql(**kwargs)
60
61    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
62        return self.copy(_mode=saveMode)
63
64    @property
65    def byName(self):
66        return self.copy(by_name=True)
67
68    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
69        output_expression_container = exp.Insert(
70            **{
71                "this": exp.to_table(tableName),
72                "overwrite": overwrite,
73            }
74        )
75        df = self._df.copy(output_expression_container=output_expression_container)
76        if self._by_name:
77            columns = sqlglot.schema.column_names(tableName, only_visible=True)
78            df = df._convert_leaf_to_cte().select(*columns)
79
80        return self.copy(_df=df)
81
82    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
83        if format is not None:
84            raise NotImplementedError("Providing Format in the save as table is not supported")
85        exists, replace, mode = None, None, mode or str(self._mode)
86        if mode == "append":
87            return self.insertInto(name)
88        if mode == "ignore":
89            exists = True
90        if mode == "overwrite":
91            replace = True
92        output_expression_container = exp.Create(
93            this=exp.to_table(name),
94            kind="TABLE",
95            exists=exists,
96            replace=replace,
97        )
98        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
DataFrameWriter( df: sqlglot.dataframe.sql.DataFrame, spark: Optional[sqlglot.dataframe.sql.SparkSession] = None, mode: Optional[str] = None, by_name: bool = False)
38    def __init__(
39        self,
40        df: DataFrame,
41        spark: t.Optional[SparkSession] = None,
42        mode: t.Optional[str] = None,
43        by_name: bool = False,
44    ):
45        self._df = df
46        self._spark = spark or df.spark
47        self._mode = mode
48        self._by_name = by_name
def copy(self, **kwargs) -> sqlglot.dataframe.sql.DataFrameWriter:
50    def copy(self, **kwargs) -> DataFrameWriter:
51        return DataFrameWriter(
52            **{
53                k[1:] if k.startswith("_") else k: v
54                for k, v in object_to_dict(self, **kwargs).items()
55            }
56        )
def sql(self, **kwargs) -> List[str]:
58    def sql(self, **kwargs) -> t.List[str]:
59        return self._df.sql(**kwargs)
def mode( self, saveMode: Optional[str]) -> sqlglot.dataframe.sql.DataFrameWriter:
61    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
62        return self.copy(_mode=saveMode)
def insertInto( self, tableName: str, overwrite: Optional[bool] = None) -> sqlglot.dataframe.sql.DataFrameWriter:
68    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
69        output_expression_container = exp.Insert(
70            **{
71                "this": exp.to_table(tableName),
72                "overwrite": overwrite,
73            }
74        )
75        df = self._df.copy(output_expression_container=output_expression_container)
76        if self._by_name:
77            columns = sqlglot.schema.column_names(tableName, only_visible=True)
78            df = df._convert_leaf_to_cte().select(*columns)
79
80        return self.copy(_df=df)
def saveAsTable( self, name: str, format: Optional[str] = None, mode: Optional[str] = None):
82    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
83        if format is not None:
84            raise NotImplementedError("Providing Format in the save as table is not supported")
85        exists, replace, mode = None, None, mode or str(self._mode)
86        if mode == "append":
87            return self.insertInto(name)
88        if mode == "ignore":
89            exists = True
90        if mode == "overwrite":
91            replace = True
92        output_expression_container = exp.Create(
93            this=exp.to_table(name),
94            kind="TABLE",
95            exists=exists,
96            replace=replace,
97        )
98        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
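saveAsTable maps Spark save modes onto SQL statements: "overwrite" becomes CREATE OR REPLACE TABLE, "ignore" becomes CREATE TABLE IF NOT EXISTS, and "append" is routed through insertInto, producing an INSERT. A sketch with an illustrative target table name:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"id": "INT", "name": "STRING"})
spark = SparkSession()
df = spark.table("employee")

create = df.write.mode("overwrite").saveAsTable("employee_copy")  # CREATE OR REPLACE TABLE ...
insert = df.write.mode("append").saveAsTable("employee_copy")     # INSERT INTO ...
print(create.sql()[0])
print(insert.sql()[0])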