
sqlglot.dataframe.sql

from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
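
The typical workflow is to register table schemas with sqlglot.schema, build up a lazy DataFrame, and then render SQL with .sql(). A minimal sketch (the employee table and its columns are illustrative, not part of the API):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Column types let the optimizer qualify and validate the query
sqlglot.schema.add_table("employee", {"employee_id": "INT", "fname": "STRING", "age": "INT"})

spark = SparkSession()
df = (
    spark.table("employee")
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)

print(df.sql(pretty=True)[0])  # .sql() returns a list of SQL strings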
class SparkSession:
class SparkSession:
    known_ids: t.ClassVar[t.Set[str]] = set()
    known_branch_ids: t.ClassVar[t.Set[str]] = set()
    known_sequence_ids: t.ClassVar[t.Set[str]] = set()
    name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)

    def __init__(self):
        self.incrementing_id = 1

    def __getattr__(self, name: str) -> SparkSession:
        return self

    def __call__(self, *args, **kwargs) -> SparkSession:
        return self

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError("Schema must be a StructType, a string, or a list of strings")
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.Tuple(
                expressions=list(
                    map(
                        lambda x: F.lit(x).expression,
                        row if not isinstance(row, dict) else row.values(),
                    )
                )
            )
            for row in data
        ]

        sel_columns = [
            F.col(name).cast(data_type).alias(name).expression
            if data_type is not None
            else F.col(name).expression
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                expressions=[
                    exp.Values(
                        expressions=data_expressions,
                        alias=exp.TableAlias(
                            this=exp.to_identifier(self._auto_incrementing_name),
                            columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                        ),
                    ),
                ],
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        expression = sqlglot.parse_one(sqlQuery, read="spark")
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_name(self) -> str:
        return "r" + uuid.uuid4().hex

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self):
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = self._random_name
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)
def table(self, tableName: str) -> sqlglot.dataframe.sql.DataFrame:
def createDataFrame( self, data: Sequence[Union[Dict[str, ColumnLiterals], List[ColumnLiterals], Tuple]], schema: Optional[SchemaInput] = None, samplingRatio: Optional[float] = None, verifySchema: bool = False) -> sqlglot.dataframe.sql.DataFrame:
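
createDataFrame embeds the rows as a VALUES clause behind an auto-generated alias, so no schema registration is needed. A small sketch (column names illustrative):

spark = SparkSession()
df = spark.createDataFrame([[1, "Jack"], [2, "Jill"]], ["id", "name"])
# Renders roughly: SELECT id, name FROM VALUES (1, 'Jack'), (2, 'Jill') AS a1(id, name)
print(df.sql(optimize=False)[0])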
def sql(self, sqlQuery: str) -> sqlglot.dataframe.sql.DataFrame:
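
SELECT statements come back as a transformable DataFrame, while CREATE and INSERT are kept as the output container so a later DataFrame.sql() call reproduces the full statement. A sketch (table t is illustrative and assumed registered in sqlglot.schema):

sqlglot.schema.add_table("t", {"a": "INT", "b": "INT"})
df = spark.sql("SELECT a, b FROM t").where(F.col("a") > F.lit(0))
print(df.sql()[0])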
class DataFrame:
class DataFrame:
    def __init__(
        self,
        spark: SparkSession,
        expression: exp.Select,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        last_op: Operation = Operation.INIT,
        pending_hints: t.Optional[t.List[exp.Expression]] = None,
        output_expression_container: t.Optional[OutputExpressionContainer] = None,
        **kwargs,
    ):
        self.spark = spark
        self.expression = expression
        self.branch_id = branch_id or self.spark._random_branch_id
        self.sequence_id = sequence_id or self.spark._random_sequence_id
        self.last_op = last_op
        self.pending_hints = pending_hints or []
        self.output_expression_container = output_expression_container or exp.Select()

    def __getattr__(self, column_name: str) -> Column:
        return self[column_name]

    def __getitem__(self, column_name: str) -> Column:
        column_name = f"{self.branch_id}.{column_name}"
        return Column(column_name)

    def __copy__(self):
        return self.copy()

    @property
    def sparkSession(self):
        return self.spark

    @property
    def write(self):
        return DataFrameWriter(self)

    @property
    def latest_cte_name(self) -> str:
        if not self.expression.ctes:
            from_exp = self.expression.args["from"]
            if from_exp.alias_or_name:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

    @property
    def pending_join_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]

    @property
    def pending_partition_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]

    @property
    def columns(self) -> t.List[str]:
        return self.expression.named_selects

    @property
    def na(self) -> DataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    def _replace_cte_names_with_hashes(self, expression: exp.Select):
        replacement_mapping = {}
        for cte in expression.ctes:
            old_name_id = cte.args["alias"].this
            new_hashed_id = exp.to_identifier(
                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
            )
            replacement_mapping[old_name_id] = new_hashed_id
            expression = expression.transform(replace_id_value, replacement_mapping)
        return expression

    def _create_cte_from_expression(
        self,
        expression: exp.Expression,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        name = self.spark._random_name
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
        cte.set("branch_id", branch_id or self.branch_id)
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
        ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
        ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols):
        cols = self._ensure_list_of_columns(cols)
        normalize(self.spark, self.expression, cols)
        return cols

    def _ensure_and_normalize_col(self, col):
        col = Column.ensure_col(col)
        normalize(self.spark, self.expression, col)
        return col

    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df

    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)
        return new_df

    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
        other_df = other._convert_leaf_to_cte()
        base_expression = self.expression.copy()
        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
        all_ctes = base_expression.ctes
        other_df.expression.set("with", None)
        base_expression.set("with", None)
        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
        operation.set("with", exp.With(expressions=all_ctes))
        return self.copy(expression=operation)._convert_leaf_to_cte()

    def _cache(self, storage_level: str):
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df

    @classmethod
    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
        expression = expression.copy()
        with_expression = expression.args.get("with")
        if with_expression:
            existing_ctes = with_expression.expressions
            existing_cte_names = {x.alias_or_name for x in existing_ctes}
            for cte in ctes:
                if cte.alias_or_name not in existing_cte_names:
                    existing_ctes.append(cte)
        else:
            existing_ctes = ctes
        expression.set("with", exp.With(expressions=existing_ctes))
        return expression

    @classmethod
    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
        expression = item.expression if isinstance(item, DataFrame) else item
        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Select):
        value = expression.sql(dialect="spark").encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]

    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")
            if cache_storage_level:
                select_expression = cte.this.copy()
                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
                select_expression.set("cte_alias_name", cte.alias_or_name)
                select_expression.set("cache_storage_level", cache_storage_level)
                select_expressions.append((exp.Cache, select_expression))
            else:
                main_select_ctes.append(cte)
        main_select = self.expression.copy()
        if main_select_ctes:
            main_select.set("with", exp.With(expressions=main_select_ctes))
        expression_select_pair = (type(self.output_expression_container), main_select)
        select_expressions.append(expression_select_pair)  # type: ignore
        return select_expressions

    def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                select_expression = optimize_func(select_expression, identify="always")
            select_expression = df._replace_cte_names_with_hashes(select_expression)
            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql("spark")
                        for expression in select_expression.expressions
                    },
                )
                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )
                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)
                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")
            output_expressions.append(expression)

        return [
            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
        ]

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [col for col in cols if not col.column_expression.table]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # If the select column does not specify a table and there is a join
                    # then we assume they are referring to the left table
                    if len(ctes_with_column) > 1:
                        table_identifier = self.expression.args["from"].args["expressions"][0].this
                    else:
                        table_identifier = ctes_with_column[0].args["alias"].this
                    ambiguous_col.expression.set("table", table_identifier)
        return self.copy(
            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
        )

    @operation(Operation.NO_OP)
    def alias(self, name: str, **kwargs) -> DataFrame:
        new_sequence_id = self.spark._random_sequence_id
        df = self.copy()
        for join_hint in df.pending_join_hints:
            for expression in join_hint.expressions:
                if expression.alias_or_name == self.sequence_id:
                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
        df.spark._add_alias_to_mapping(name, new_sequence_id)
        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

    @operation(Operation.WHERE)
    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
        col = self._ensure_and_normalize_col(column)
        return self.copy(expression=self.expression.where(col.expression))

    filter = where

    @operation(Operation.GROUP_BY)
    def groupBy(self, *cols, **kwargs) -> GroupedData:
        columns = self._ensure_and_normalize_cols(cols)
        return GroupedData(self, columns, self.last_op)

    @operation(Operation.SELECT)
    def agg(self, *exprs, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(exprs)
        return self.groupBy().agg(*cols)

    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        other_df = other_df._convert_leaf_to_cte()
        pre_join_self_latest_cte_name = self.latest_cte_name
        columns = self._ensure_and_normalize_cols(on)
        join_type = how.replace("_", " ")
        if isinstance(columns[0].expression, exp.Column):
            join_columns = [
                Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
            ]
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [
                    col.copy().set_table_name(pre_join_self_latest_cte_name)
                    == col.copy().set_table_name(other_df.latest_cte_name)
                    for col in columns
                ],
            )
        else:
            if len(columns) > 1:
                columns = [functools.reduce(lambda x, y: x & y, columns)]
            join_clause = columns[0]
            join_columns = [
                Column(x).set_table_name(pre_join_self_latest_cte_name)
                if i % 2 == 0
                else Column(x).set_table_name(other_df.latest_cte_name)
                for i, x in enumerate(join_clause.expression.find_all(exp.Column))
            ]
        self_columns = [
            column.set_table_name(pre_join_self_latest_cte_name, copy=True)
            for column in self._get_outer_select_columns(self)
        ]
        other_columns = [
            column.set_table_name(other_df.latest_cte_name, copy=True)
            for column in self._get_outer_select_columns(other_df)
        ]
        column_value_mapping = {
            column.alias_or_name
            if not isinstance(column.expression.this, exp.Star)
            else column.sql(): column
            for column in other_columns + self_columns + join_columns
        }
        all_columns = [
            column_value_mapping[name]
            for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
        ]
        new_df = self.copy(
            expression=self.expression.join(
                other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
            )
        )
        new_df.expression = new_df._add_ctes_to_expression(
            new_df.expression, other_df.expression.ctes
        )
        new_df.pending_hints.extend(other_df.pending_hints)
        new_df = new_df.select.__wrapped__(new_df, *all_columns)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
        """
        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
        is unlikely to come up.
        """
        columns = self._ensure_and_normalize_cols(cols)
        pre_ordered_col_indexes = [
            x
            for x in [
                i if isinstance(col.expression, exp.Ordered) else None
                for i, col in enumerate(columns)
            ]
            if x is not None
        ]
        if ascending is None:
            ascending = [True] * len(columns)
        elif not isinstance(ascending, list):
            ascending = [ascending] * len(columns)
        ascending = [bool(x) for x in ascending]
        assert len(columns) == len(
            ascending
        ), "The length of items in ascending must equal the number of columns provided"
        col_and_ascending = list(zip(columns, ascending))
        order_by_columns = [
            exp.Ordered(this=col.expression, desc=not asc)
            if i not in pre_ordered_col_indexes
            else columns[i].column_expression
            for i, (col, asc) in enumerate(col_and_ascending)
        ]
        return self.copy(expression=self.expression.order_by(*order_by_columns))

    sort = orderBy

    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Union, other, False)

    unionAll = union

    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    r_expressions.append(exp.alias_(exp.Null(), l_column))
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)

    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, True)

    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, False)

    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Except, other, False)

    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        return self.copy(expression=self.expression.distinct())

    @operation(Operation.SELECT)
    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
        if not subset:
            return self.distinct()
        column_names = ensure_list(subset)
        window = Window.partitionBy(*column_names).orderBy(*column_names)
        return (
            self.copy()
            .withColumn("row_num", F.row_number().over(window))
            .where(F.col("row_num") == F.lit(1))
            .drop("row_num")
        )

    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        minimum_non_null = thresh or 0  # will be determined later if thresh is null
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
        if minimum_num_nulls > len(null_check_columns):
            raise RuntimeError(
                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
            )
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        final_df = filtered_df.select(*all_columns)
        return final_df

    @operation(Operation.FROM)
    def fillna(
        self,
        value: ColumnLiterals,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Functionality difference: if the type of a replacement value conflicts with the type of the
        column, PySpark simply ignores the replacement, whereas this implementation may cast the two
        to match. Results won't always agree, so avoid mixing types and keep the replacement the
        same type as the column.

        Possible improvement: use the `typeof` function to get the column's type and check whether
        it matches the type of the value provided; if not, replace with null.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df

    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        if existing_col_index is not None:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)

    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column.copy(), new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)

    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)

    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
DataFrame( spark: sqlglot.dataframe.sql.SparkSession, expression: sqlglot.expressions.Select, branch_id: Optional[str] = None, sequence_id: Optional[str] = None, last_op: sqlglot.dataframe.sql.operations.Operation = <Operation.INIT: -1>, pending_hints: Optional[List[sqlglot.expressions.Expression]] = None, output_expression_container: Optional[OutputExpressionContainer] = None, **kwargs)
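
DataFrames are immutable: every transformation calls copy() and returns a new object, while branch_id and sequence_id track lineage so joins and hints can later resolve the right CTEs. For example (names illustrative):

df = spark.table("employee")
adults = df.where(df.age >= F.lit(18))   # df itself is unchanged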
def sql(self, dialect='spark', optimize=True, **kwargs) -> List[str]:
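
Because sqlglot transpiles, the same DataFrame can render SQL for other dialects, and optimize=False skips the optimizer (which otherwise needs column types registered in sqlglot.schema). A sketch:

print(df.sql(dialect="bigquery")[0])   # transpile the plan to another dialect
print(df.sql(optimize=False)[0])       # raw CTE chain without optimization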
def copy(self, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
@operation(Operation.SELECT)
def select(self, *cols, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
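
select accepts either column names or Column expressions, e.g. (columns illustrative):

df.select("fname", "age")
df.select(df.fname, (F.col("age") + F.lit(1)).alias("age_next_year"))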
@operation(Operation.NO_OP)
def alias(self, name: str, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
@operation(Operation.WHERE)
def where( self, column: Union[sqlglot.dataframe.sql.Column, bool], **kwargs) -> sqlglot.dataframe.sql.DataFrame:
@operation(Operation.WHERE)
def filter( self, column: Union[sqlglot.dataframe.sql.Column, bool], **kwargs) -> sqlglot.dataframe.sql.DataFrame:
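
filter is a straight alias of where; both append to the WHERE clause, so chained calls combine with AND. Sketch:

df.where(F.col("age") >= F.lit(18))
df.filter(df.age >= F.lit(18))   # identical result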
@operation(Operation.GROUP_BY)
def groupBy(self, *cols, **kwargs) -> sqlglot.dataframe.sql.GroupedData:
@operation(Operation.SELECT)
def agg(self, *exprs, **kwargs) -> sqlglot.dataframe.sql.DataFrame:
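
agg on a DataFrame is shorthand for an empty groupBy, i.e. a global aggregate. Sketch (with sqlglot.dataframe.sql.functions imported as F):

df.groupBy(F.col("age")).agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
df.agg(F.max(F.col("age")))   # same as df.groupBy().agg(...)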
@operation(Operation.FROM)
def join( self, other_df: sqlglot.dataframe.sql.DataFrame, on: Union[str, List[str], sqlglot.dataframe.sql.Column, List[sqlglot.dataframe.sql.Column]], how: str = 'inner', **kwargs) -> sqlglot.dataframe.sql.DataFrame:
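
When on is a list of column names, each name becomes an equality predicate between the two latest CTEs and the duplicated key columns are deduplicated in the output; a Column expression is used as the join condition as-is. Sketch (tables illustrative):

store = spark.table("store")
df.join(store, on=["store_id"], how="left_outer").select(df.fname, store.store_name)
df.join(store, on=df.store_id == store.store_id, how="inner")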
@operation(Operation.ORDER_BY)
def orderBy( self, *cols: Union[str, sqlglot.dataframe.sql.Column], ascending: Union[Any, List[Any], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
475    @operation(Operation.ORDER_BY)
476    def orderBy(
477        self,
478        *cols: t.Union[str, Column],
479        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
480    ) -> DataFrame:
481        """
482        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
483        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
484        is unlikely to come up.
485        """
486        columns = self._ensure_and_normalize_cols(cols)
487        pre_ordered_col_indexes = [
488            x
489            for x in [
490                i if isinstance(col.expression, exp.Ordered) else None
491                for i, col in enumerate(columns)
492            ]
493            if x is not None
494        ]
495        if ascending is None:
496            ascending = [True] * len(columns)
497        elif not isinstance(ascending, list):
498            ascending = [ascending] * len(columns)
499        ascending = [bool(x) for x in ascending]
500        assert len(columns) == len(
501            ascending
502        ), "The length of items in ascending must equal the number of columns provided"
503        col_and_ascending = list(zip(columns, ascending))
504        order_by_columns = [
505            exp.Ordered(this=col.expression, desc=not asc)
506            if i not in pre_ordered_col_indexes
507            else columns[i].column_expression
508            for i, (col, asc) in enumerate(col_and_ascending)
509        ]
510        return self.copy(expression=self.expression.order_by(*order_by_columns))

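Example, a short sketch with illustrative column names: pre-ordered columns take priority, so either pass columns that already carry an ordering, or pass plain columns plus an `ascending` list of matching length, but not both:

    from sqlglot.dataframe.sql import SparkSession, functions as F

    spark = SparkSession()
    df = spark.createDataFrame([(1, "a", 10)], ["id", "name", "dept_id"])

    df.orderBy(F.col("name").desc())                        # pre-ordered column
    df.orderBy("dept_id", "name", ascending=[True, False])  # one flag per column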

@operation(Operation.ORDER_BY)
def sort( self, *cols: Union[str, sqlglot.dataframe.sql.Column], ascending: Union[Any, List[Any], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
475    @operation(Operation.ORDER_BY)
476    def orderBy(
477        self,
478        *cols: t.Union[str, Column],
479        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
480    ) -> DataFrame:
481        """
482        This implementation lets any pre-ordered columns take priority over whatever is provided in
483        `ascending`. Spark's behavior here is irregular and can raise runtime errors. Users shouldn't
484        mix the two anyway, so this is unlikely to come up.
485        """
486        columns = self._ensure_and_normalize_cols(cols)
487        pre_ordered_col_indexes = [
488            x
489            for x in [
490                i if isinstance(col.expression, exp.Ordered) else None
491                for i, col in enumerate(columns)
492            ]
493            if x is not None
494        ]
495        if ascending is None:
496            ascending = [True] * len(columns)
497        elif not isinstance(ascending, list):
498            ascending = [ascending] * len(columns)
499        ascending = [bool(x) for x in ascending]
500        assert len(columns) == len(
501            ascending
502        ), "The length of items in ascending must equal the number of columns provided"
503        col_and_ascending = list(zip(columns, ascending))
504        order_by_columns = [
505            exp.Ordered(this=col.expression, desc=not asc)
506            if i not in pre_ordered_col_indexes
507            else columns[i].column_expression
508            for i, (col, asc) in enumerate(col_and_ascending)
509        ]
510        return self.copy(expression=self.expression.order_by(*order_by_columns))


@operation(Operation.FROM)
def union( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
514    @operation(Operation.FROM)
515    def union(self, other: DataFrame) -> DataFrame:
516        return self._set_operation(exp.Union, other, False)
@operation(Operation.FROM)
def unionAll( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
514    @operation(Operation.FROM)
515    def union(self, other: DataFrame) -> DataFrame:
516        return self._set_operation(exp.Union, other, False)
@operation(Operation.FROM)
def unionByName( self, other: sqlglot.dataframe.sql.DataFrame, allowMissingColumns: bool = False):
520    @operation(Operation.FROM)
521    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
522        l_columns = self.columns
523        r_columns = other.columns
524        if not allowMissingColumns:
525            l_expressions = l_columns
526            r_expressions = l_columns
527        else:
528            l_expressions = []
529            r_expressions = []
530            r_columns_unused = copy(r_columns)
531            for l_column in l_columns:
532                l_expressions.append(l_column)
533                if l_column in r_columns:
534                    r_expressions.append(l_column)
535                    r_columns_unused.remove(l_column)
536                else:
537                    r_expressions.append(exp.alias_(exp.Null(), l_column))
538            for r_column in r_columns_unused:
539                l_expressions.append(exp.alias_(exp.Null(), r_column))
540                r_expressions.append(r_column)
541        r_df = (
542            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
543        )
544        l_df = self.copy()
545        if allowMissingColumns:
546            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
547        return l_df._set_operation(exp.Union, r_df, False)
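
Example, a sketch of the null-padding behavior with illustrative names:

    from sqlglot.dataframe.sql import SparkSession

    spark = SparkSession()
    left = spark.createDataFrame([(1, "a")], ["id", "name"])
    right = spark.createDataFrame([(2, 99)], ["id", "score"])

    # Columns are matched by name. With allowMissingColumns=True, a column
    # missing on one side is filled with NULL, so the result has id, name, score.
    unioned = left.unionByName(right, allowMissingColumns=True)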
@operation(Operation.FROM)
def intersect( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
549    @operation(Operation.FROM)
550    def intersect(self, other: DataFrame) -> DataFrame:
551        return self._set_operation(exp.Intersect, other, True)
@operation(Operation.FROM)
def intersectAll( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
553    @operation(Operation.FROM)
554    def intersectAll(self, other: DataFrame) -> DataFrame:
555        return self._set_operation(exp.Intersect, other, False)
@operation(Operation.FROM)
def exceptAll( self, other: sqlglot.dataframe.sql.DataFrame) -> sqlglot.dataframe.sql.DataFrame:
557    @operation(Operation.FROM)
558    def exceptAll(self, other: DataFrame) -> DataFrame:
559        return self._set_operation(exp.Except, other, False)
@operation(Operation.SELECT)
def distinct(self) -> sqlglot.dataframe.sql.DataFrame:
561    @operation(Operation.SELECT)
562    def distinct(self) -> DataFrame:
563        return self.copy(expression=self.expression.distinct())
@operation(Operation.SELECT)
def dropDuplicates(self, subset: Optional[List[str]] = None):
565    @operation(Operation.SELECT)
566    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
567        if not subset:
568            return self.distinct()
569        column_names = ensure_list(subset)
570        window = Window.partitionBy(*column_names).orderBy(*column_names)
571        return (
572            self.copy()
573            .withColumn("row_num", F.row_number().over(window))
574            .where(F.col("row_num") == F.lit(1))
575            .drop("row_num")
576        )
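
Example, a sketch showing both paths through dropDuplicates:

    from sqlglot.dataframe.sql import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([(1, "a"), (1, "b")], ["id", "name"])

    df.dropDuplicates()        # no subset: plain SELECT DISTINCT
    df.dropDuplicates(["id"])  # subset: ROW_NUMBER() over a window keyed on id,
                               # keeping the first row per key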
@operation(Operation.FROM)
def dropna( self, how: str = 'any', thresh: Optional[int] = None, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
578    @operation(Operation.FROM)
579    def dropna(
580        self,
581        how: str = "any",
582        thresh: t.Optional[int] = None,
583        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
584    ) -> DataFrame:
585        minimum_non_null = thresh or 0  # will be determined later if thresh is null
586        new_df = self.copy()
587        all_columns = self._get_outer_select_columns(new_df.expression)
588        if subset:
589            null_check_columns = self._ensure_and_normalize_cols(subset)
590        else:
591            null_check_columns = all_columns
592        if thresh is None:
593            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
594        else:
595            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
596        if minimum_num_nulls > len(null_check_columns):
597            raise RuntimeError(
598                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
599                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
600            )
601        if_null_checks = [
602            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
603        ]
604        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
605        num_nulls = nulls_added_together.alias("num_nulls")
606        new_df = new_df.select(num_nulls, append=True)
607        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
608        final_df = filtered_df.select(*all_columns)
609        return final_df
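
Example, a sketch of how `how` and `thresh` interact (None literals stand in for NULLs):

    from sqlglot.dataframe.sql import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([(1, None), (None, None)], ["a", "b"])

    df.dropna(how="any")  # drops rows with at least one NULL (both rows here)
    df.dropna(how="all")  # drops only rows where every checked column is NULL
    df.dropna(thresh=1)   # keeps rows with at least one non-null value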
@operation(Operation.FROM)
def fillna( self, value: ColumnLiterals, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
611    @operation(Operation.FROM)
612    def fillna(
613        self,
614        value: t.Union[ColumnLiterals],
615        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
616    ) -> DataFrame:
617        """
618        Functionality Difference: If the replacement value's type conflicts with the column's
619        type, PySpark simply ignores the replacement. This implementation tries to cast them
620        to a common type in some cases, so results won't always match. It's best not to mix
621        types: make sure the replacement is the same type as the column.
622
623        Possibility for improvement: Use the `typeof` function to get the type of the column
624        and check if it matches the type of the value provided. If not, make it null.
625        """
626        from sqlglot.dataframe.sql.functions import lit
627
628        values = None
629        columns = None
630        new_df = self.copy()
631        all_columns = self._get_outer_select_columns(new_df.expression)
632        all_column_mapping = {column.alias_or_name: column for column in all_columns}
633        if isinstance(value, dict):
634            values = list(value.values())
635            columns = self._ensure_and_normalize_cols(list(value))
636        if not columns:
637            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
638        if not values:
639            values = [value] * len(columns)
640        value_columns = [lit(value) for value in values]
641
642        null_replacement_mapping = {
643            column.alias_or_name: (
644                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
645            )
646            for column, value in zip(columns, value_columns)
647        }
648        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
649        null_replacement_columns = [
650            null_replacement_mapping[column.alias_or_name] for column in all_columns
651        ]
652        new_df = new_df.select(*null_replacement_columns)
653        return new_df

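Example, a sketch of the three accepted shapes for `value`:

    from sqlglot.dataframe.sql import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([(None, None)], ["a", "b"])

    df.fillna(0)                     # scalar: fills every column
    df.fillna(0, subset=["a"])       # scalar restricted to a subset
    df.fillna({"a": 0, "b": "n/a"})  # dict: per-column values; subset is ignored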

@operation(Operation.FROM)
def replace( self, to_replace: Union[bool, int, float, str, List, Dict], value: Union[bool, int, float, str, List, NoneType] = None, subset: Union[Collection[ColumnOrName], ColumnOrName, NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
655    @operation(Operation.FROM)
656    def replace(
657        self,
658        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
659        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
660        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
661    ) -> DataFrame:
662        from sqlglot.dataframe.sql.functions import lit
663
664        old_values = None
665        new_df = self.copy()
666        all_columns = self._get_outer_select_columns(new_df.expression)
667        all_column_mapping = {column.alias_or_name: column for column in all_columns}
668
669        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
670        if isinstance(to_replace, dict):
671            old_values = list(to_replace)
672            new_values = list(to_replace.values())
673        elif not old_values and isinstance(to_replace, list):
674            assert isinstance(value, list), "value must be a list since the replacements are a list"
675            assert len(to_replace) == len(
676                value
677            ), "the replacements and values must be the same length"
678            old_values = to_replace
679            new_values = value
680        else:
681            old_values = [to_replace] * len(columns)
682            new_values = [value] * len(columns)
683        old_values = [lit(value) for value in old_values]
684        new_values = [lit(value) for value in new_values]
685
686        replacement_mapping = {}
687        for column in columns:
688            expression = Column(None)
689            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
690                if i == 0:
691                    expression = F.when(column == old_value, new_value)
692                else:
693                    expression = expression.when(column == old_value, new_value)  # type: ignore
694            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
695                column.expression.alias_or_name
696            )
697
698        replacement_mapping = {**all_column_mapping, **replacement_mapping}
699        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
700        new_df = new_df.select(*replacement_columns)
701        return new_df
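
Example, a sketch of the three accepted shapes for `to_replace`, each compiling to a CASE expression that falls back to the original column:

    from sqlglot.dataframe.sql import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([(1, 2)], ["a", "b"])

    df.replace(1, 100)                          # scalar -> scalar, all columns
    df.replace([1, 2], [10, 20], subset=["a"])  # paired lists, column a only
    df.replace({1: 10, 2: 20})                  # dict of old -> new values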
@operation(Operation.SELECT)
def withColumn( self, colName: str, col: sqlglot.dataframe.sql.Column) -> sqlglot.dataframe.sql.DataFrame:
703    @operation(Operation.SELECT)
704    def withColumn(self, colName: str, col: Column) -> DataFrame:
705        col = self._ensure_and_normalize_col(col)
706        existing_col_names = self.expression.named_selects
707        existing_col_index = (
708            existing_col_names.index(colName) if colName in existing_col_names else None
709        )
710        if existing_col_index is not None:
711            expression = self.expression.copy()
712            expression.expressions[existing_col_index] = col.expression
713            return self.copy(expression=expression)
714        return self.copy().select(col.alias(colName), append=True)
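
Example, a sketch of the two branches, appending a new column versus replacing an existing one in place:

    from sqlglot.dataframe.sql import SparkSession, functions as F

    spark = SparkSession()
    df = spark.createDataFrame([(1,)], ["a"])

    df2 = df.withColumn("b", F.col("a") + 1)  # new name: appended to the select
    df3 = df2.withColumn("a", F.lit(0))       # existing name: replaced in place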
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
716    @operation(Operation.SELECT)
717    def withColumnRenamed(self, existing: str, new: str):
718        expression = self.expression.copy()
719        existing_columns = [
720            expression
721            for expression in expression.expressions
722            if expression.alias_or_name == existing
723        ]
724        if not existing_columns:
725            raise ValueError("Tried to rename a column that doesn't exist")
726        for existing_column in existing_columns:
727            if isinstance(existing_column, exp.Column):
728                existing_column.replace(exp.alias_(existing_column.copy(), new))
729            else:
730                existing_column.set("alias", exp.to_identifier(new))
731        return self.copy(expression=expression)
@operation(Operation.SELECT)
def drop( self, *cols: Union[str, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.DataFrame:
733    @operation(Operation.SELECT)
734    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
735        all_columns = self._get_outer_select_columns(self.expression)
736        drop_cols = self._ensure_and_normalize_cols(cols)
737        new_columns = [
738            col
739            for col in all_columns
740            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
741        ]
742        return self.copy().select(*new_columns, append=False)
@operation(Operation.LIMIT)
def limit(self, num: int) -> sqlglot.dataframe.sql.DataFrame:
744    @operation(Operation.LIMIT)
745    def limit(self, num: int) -> DataFrame:
746        return self.copy(expression=self.expression.limit(num))
@operation(Operation.NO_OP)
def hint( self, name: str, *parameters: Union[str, int, NoneType]) -> sqlglot.dataframe.sql.DataFrame:
748    @operation(Operation.NO_OP)
749    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
750        parameter_list = ensure_list(parameters)
751        parameter_columns = (
752            self._ensure_list_of_columns(parameter_list)
753            if parameters
754            else Column.ensure_cols([self.sequence_id])
755        )
756        return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
def repartition( self, numPartitions: Union[int, ColumnOrName], *cols: ColumnOrName) -> sqlglot.dataframe.sql.DataFrame:
758    @operation(Operation.NO_OP)
759    def repartition(
760        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
761    ) -> DataFrame:
762        num_partition_cols = self._ensure_list_of_columns(numPartitions)
763        columns = self._ensure_and_normalize_cols(cols)
764        args = num_partition_cols + columns
765        return self._hint("repartition", args)
@operation(Operation.NO_OP)
def coalesce(self, numPartitions: int) -> sqlglot.dataframe.sql.DataFrame:
767    @operation(Operation.NO_OP)
768    def coalesce(self, numPartitions: int) -> DataFrame:
769        num_partitions = Column.ensure_cols([numPartitions])
770        return self._hint("coalesce", num_partitions)
@operation(Operation.NO_OP)
def cache(self) -> sqlglot.dataframe.sql.DataFrame:
772    @operation(Operation.NO_OP)
773    def cache(self) -> DataFrame:
774        return self._cache(storage_level="MEMORY_AND_DISK")
@operation(Operation.NO_OP)
def persist( self, storageLevel: str = 'MEMORY_AND_DISK_SER') -> sqlglot.dataframe.sql.DataFrame:
776    @operation(Operation.NO_OP)
777    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
778        """
779        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
780        """
781        return self._cache(storageLevel)
class GroupedData:
14class GroupedData:
15    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
16        self._df = df.copy()
17        self.spark = df.spark
18        self.last_op = last_op
19        self.group_by_cols = group_by_cols
20
21    def _get_function_applied_columns(
22        self, func_name: str, cols: t.Tuple[str, ...]
23    ) -> t.List[Column]:
24        func_name = func_name.lower()
25        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
26
27    @operation(Operation.SELECT)
28    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
29        columns = (
30            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
31            if isinstance(exprs[0], dict)
32            else exprs
33        )
34        cols = self._df._ensure_and_normalize_cols(columns)
35
36        expression = self._df.expression.group_by(
37            *[x.expression for x in self.group_by_cols]
38        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
39        return self._df.copy(expression=expression)
40
41    def count(self) -> DataFrame:
42        return self.agg(F.count("*").alias("count"))
43
44    def mean(self, *cols: str) -> DataFrame:
45        return self.avg(*cols)
46
47    def avg(self, *cols: str) -> DataFrame:
48        return self.agg(*self._get_function_applied_columns("avg", cols))
49
50    def max(self, *cols: str) -> DataFrame:
51        return self.agg(*self._get_function_applied_columns("max", cols))
52
53    def min(self, *cols: str) -> DataFrame:
54        return self.agg(*self._get_function_applied_columns("min", cols))
55
56    def sum(self, *cols: str) -> DataFrame:
57        return self.agg(*self._get_function_applied_columns("sum", cols))
58
59    def pivot(self, *cols: str) -> DataFrame:
60        raise NotImplementedError("Sum distinct is not currently implemented")
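
Example, a sketch assuming the DataFrame.groupBy method documented earlier in this module returns a GroupedData:

    from sqlglot.dataframe.sql import SparkSession, functions as F

    spark = SparkSession()
    df = spark.createDataFrame([(10, 50000), (10, 60000)], ["dept_id", "salary"])

    df.groupBy("dept_id").agg(F.sum("salary").alias("total"))  # explicit aggregate
    df.groupBy("dept_id").avg("salary")           # helper, aliased as avg(salary)
    df.groupBy("dept_id").agg({"salary": "max"})  # dict form, becomes max(salary)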
GroupedData( df: sqlglot.dataframe.sql.DataFrame, group_by_cols: List[sqlglot.dataframe.sql.Column], last_op: sqlglot.dataframe.sql.operations.Operation)
15    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
16        self._df = df.copy()
17        self.spark = df.spark
18        self.last_op = last_op
19        self.group_by_cols = group_by_cols
@operation(Operation.SELECT)
def agg( self, *exprs: Union[sqlglot.dataframe.sql.Column, Dict[str, str]]) -> sqlglot.dataframe.sql.DataFrame:
27    @operation(Operation.SELECT)
28    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
29        columns = (
30            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
31            if isinstance(exprs[0], dict)
32            else exprs
33        )
34        cols = self._df._ensure_and_normalize_cols(columns)
35
36        expression = self._df.expression.group_by(
37            *[x.expression for x in self.group_by_cols]
38        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
39        return self._df.copy(expression=expression)
def count(self) -> sqlglot.dataframe.sql.DataFrame:
41    def count(self) -> DataFrame:
42        return self.agg(F.count("*").alias("count"))
def mean(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
44    def mean(self, *cols: str) -> DataFrame:
45        return self.avg(*cols)
def avg(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
47    def avg(self, *cols: str) -> DataFrame:
48        return self.agg(*self._get_function_applied_columns("avg", cols))
def max(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
50    def max(self, *cols: str) -> DataFrame:
51        return self.agg(*self._get_function_applied_columns("max", cols))
def min(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
53    def min(self, *cols: str) -> DataFrame:
54        return self.agg(*self._get_function_applied_columns("min", cols))
def sum(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
56    def sum(self, *cols: str) -> DataFrame:
57        return self.agg(*self._get_function_applied_columns("sum", cols))
def pivot(self, *cols: str) -> sqlglot.dataframe.sql.DataFrame:
59    def pivot(self, *cols: str) -> DataFrame:
60        raise NotImplementedError("Sum distinct is not currently implemented")
class Column:
 16class Column:
 17    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
 18        if isinstance(expression, Column):
 19            expression = expression.expression  # type: ignore
 20        elif expression is None or not isinstance(expression, (str, exp.Expression)):
 21            expression = self._lit(expression).expression  # type: ignore
 22
 23        expression = sqlglot.maybe_parse(expression, dialect="spark")
 24        if expression is None:
 25            raise ValueError(f"Could not parse {expression}")
 26        self.expression: exp.Expression = expression
 27
 28    def __repr__(self):
 29        return repr(self.expression)
 30
 31    def __hash__(self):
 32        return hash(self.expression)
 33
 34    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
 35        return self.binary_op(exp.EQ, other)
 36
 37    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
 38        return self.binary_op(exp.NEQ, other)
 39
 40    def __gt__(self, other: ColumnOrLiteral) -> Column:
 41        return self.binary_op(exp.GT, other)
 42
 43    def __ge__(self, other: ColumnOrLiteral) -> Column:
 44        return self.binary_op(exp.GTE, other)
 45
 46    def __lt__(self, other: ColumnOrLiteral) -> Column:
 47        return self.binary_op(exp.LT, other)
 48
 49    def __le__(self, other: ColumnOrLiteral) -> Column:
 50        return self.binary_op(exp.LTE, other)
 51
 52    def __and__(self, other: ColumnOrLiteral) -> Column:
 53        return self.binary_op(exp.And, other)
 54
 55    def __or__(self, other: ColumnOrLiteral) -> Column:
 56        return self.binary_op(exp.Or, other)
 57
 58    def __mod__(self, other: ColumnOrLiteral) -> Column:
 59        return self.binary_op(exp.Mod, other)
 60
 61    def __add__(self, other: ColumnOrLiteral) -> Column:
 62        return self.binary_op(exp.Add, other)
 63
 64    def __sub__(self, other: ColumnOrLiteral) -> Column:
 65        return self.binary_op(exp.Sub, other)
 66
 67    def __mul__(self, other: ColumnOrLiteral) -> Column:
 68        return self.binary_op(exp.Mul, other)
 69
 70    def __truediv__(self, other: ColumnOrLiteral) -> Column:
 71        return self.binary_op(exp.Div, other)
 72
 73    def __div__(self, other: ColumnOrLiteral) -> Column:
 74        return self.binary_op(exp.Div, other)
 75
 76    def __neg__(self) -> Column:
 77        return self.unary_op(exp.Neg)
 78
 79    def __radd__(self, other: ColumnOrLiteral) -> Column:
 80        return self.inverse_binary_op(exp.Add, other)
 81
 82    def __rsub__(self, other: ColumnOrLiteral) -> Column:
 83        return self.inverse_binary_op(exp.Sub, other)
 84
 85    def __rmul__(self, other: ColumnOrLiteral) -> Column:
 86        return self.inverse_binary_op(exp.Mul, other)
 87
 88    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
 89        return self.inverse_binary_op(exp.Div, other)
 90
 91    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
 92        return self.inverse_binary_op(exp.Div, other)
 93
 94    def __rmod__(self, other: ColumnOrLiteral) -> Column:
 95        return self.inverse_binary_op(exp.Mod, other)
 96
 97    def __pow__(self, power: ColumnOrLiteral, modulo=None):
 98        return Column(exp.Pow(this=self.expression, expression=Column(power).expression))
 99
100    def __rpow__(self, power: ColumnOrLiteral):
101        return Column(exp.Pow(this=Column(power).expression, expression=self.expression))
102
103    def __invert__(self):
104        return self.unary_op(exp.Not)
105
106    def __rand__(self, other: ColumnOrLiteral) -> Column:
107        return self.inverse_binary_op(exp.And, other)
108
109    def __ror__(self, other: ColumnOrLiteral) -> Column:
110        return self.inverse_binary_op(exp.Or, other)
111
112    @classmethod
113    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
114        return cls(value)
115
116    @classmethod
117    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
118        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
119
120    @classmethod
121    def _lit(cls, value: ColumnOrLiteral) -> Column:
122        if isinstance(value, dict):
123            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
124            return cls(exp.Struct(expressions=columns))
125        return cls(exp.convert(value))
126
127    @classmethod
128    def invoke_anonymous_function(
129        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
130    ) -> Column:
131        columns = [] if column is None else [cls.ensure_col(column)]
132        column_args = [cls.ensure_col(arg) for arg in args]
133        expressions = [x.expression for x in columns + column_args]
134        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
135        return Column(new_expression)
136
137    @classmethod
138    def invoke_expression_over_column(
139        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
140    ) -> Column:
141        ensured_column = None if column is None else cls.ensure_col(column)
142        ensure_expression_values = {
143            k: [Column.ensure_col(x).expression for x in v]
144            if is_iterable(v)
145            else Column.ensure_col(v).expression
146            for k, v in kwargs.items()
147            if v is not None
148        }
149        new_expression = (
150            callable_expression(**ensure_expression_values)
151            if ensured_column is None
152            else callable_expression(
153                this=ensured_column.column_expression, **ensure_expression_values
154            )
155        )
156        return Column(new_expression)
157
158    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
159        return Column(
160            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
161        )
162
163    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
164        return Column(
165            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
166        )
167
168    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
169        return Column(klass(this=self.column_expression, **kwargs))
170
171    @property
172    def is_alias(self):
173        return isinstance(self.expression, exp.Alias)
174
175    @property
176    def is_column(self):
177        return isinstance(self.expression, exp.Column)
178
179    @property
180    def column_expression(self) -> exp.Column:
181        return self.expression.unalias()
182
183    @property
184    def alias_or_name(self) -> str:
185        return self.expression.alias_or_name
186
187    @classmethod
188    def ensure_literal(cls, value) -> Column:
189        from sqlglot.dataframe.sql.functions import lit
190
191        if isinstance(value, cls):
192            value = value.expression
193        if not isinstance(value, exp.Literal):
194            return lit(value)
195        return Column(value)
196
197    def copy(self) -> Column:
198        return Column(self.expression.copy())
199
200    def set_table_name(self, table_name: str, copy=False) -> Column:
201        expression = self.expression.copy() if copy else self.expression
202        expression.set("table", exp.to_identifier(table_name))
203        return Column(expression)
204
205    def sql(self, **kwargs) -> str:
206        return self.expression.sql(**{"dialect": "spark", **kwargs})
207
208    def alias(self, name: str) -> Column:
209        new_expression = exp.alias_(self.column_expression, name)
210        return Column(new_expression)
211
212    def asc(self) -> Column:
213        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
214        return Column(new_expression)
215
216    def desc(self) -> Column:
217        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
218        return Column(new_expression)
219
220    asc_nulls_first = asc
221
222    def asc_nulls_last(self) -> Column:
223        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
224        return Column(new_expression)
225
226    def desc_nulls_first(self) -> Column:
227        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
228        return Column(new_expression)
229
230    desc_nulls_last = desc
231
232    def when(self, condition: Column, value: t.Any) -> Column:
233        from sqlglot.dataframe.sql.functions import when
234
235        column_with_if = when(condition, value)
236        if not isinstance(self.expression, exp.Case):
237            return column_with_if
238        new_column = self.copy()
239        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
240        return new_column
241
242    def otherwise(self, value: t.Any) -> Column:
243        from sqlglot.dataframe.sql.functions import lit
244
245        true_value = value if isinstance(value, Column) else lit(value)
246        new_column = self.copy()
247        new_column.expression.set("default", true_value.column_expression)
248        return new_column
249
250    def isNull(self) -> Column:
251        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
252        return Column(new_expression)
253
254    def isNotNull(self) -> Column:
255        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
256        return Column(new_expression)
257
258    def cast(self, dataType: t.Union[str, DataType]):
259        """
260        Functionality Difference: PySpark's cast accepts an instance of the DataType class.
261        Sqlglot doesn't currently replicate this class, so it only accepts a string.
262        """
263        if isinstance(dataType, DataType):
264            dataType = dataType.simpleString()
265        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))
266
267    def startswith(self, value: t.Union[str, Column]) -> Column:
268        value = self._lit(value) if not isinstance(value, Column) else value
269        return self.invoke_anonymous_function(self, "STARTSWITH", value)
270
271    def endswith(self, value: t.Union[str, Column]) -> Column:
272        value = self._lit(value) if not isinstance(value, Column) else value
273        return self.invoke_anonymous_function(self, "ENDSWITH", value)
274
275    def rlike(self, regexp: str) -> Column:
276        return self.invoke_expression_over_column(
277            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
278        )
279
280    def like(self, other: str):
281        return self.invoke_expression_over_column(
282            self, exp.Like, expression=self._lit(other).expression
283        )
284
285    def ilike(self, other: str):
286        return self.invoke_expression_over_column(
287            self, exp.ILike, expression=self._lit(other).expression
288        )
289
290    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
291        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
292        length = self._lit(length) if not isinstance(length, Column) else length
293        return Column.invoke_expression_over_column(
294            self, exp.Substring, start=startPos.expression, length=length.expression
295        )
296
297    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
298        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
299        expressions = [self._lit(x).expression for x in columns]
300        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
301
302    def between(
303        self,
304        lowerBound: t.Union[ColumnOrLiteral],
305        upperBound: t.Union[ColumnOrLiteral],
306    ) -> Column:
307        lower_bound_exp = (
308            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
309        )
310        upper_bound_exp = (
311            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
312        )
313        return Column(
314            exp.Between(
315                this=self.column_expression,
316                low=lower_bound_exp.expression,
317                high=upper_bound_exp.expression,
318            )
319        )
320
321    def over(self, window: WindowSpec) -> Column:
322        window_expression = window.expression.copy()
323        window_expression.set("this", self.column_expression)
324        return Column(window_expression)
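
Example, a sketch of composing Column expressions with operators and when/otherwise (F.when and F.col come from this module's functions):

    from sqlglot.dataframe.sql import functions as F

    age = F.col("age")

    # Operators build sqlglot expressions; alias names the result.
    working_age = ((age >= 18) & (age < 65)).alias("working_age")

    # when/otherwise chain into a single CASE expression.
    bucket = F.when(age < 18, "minor").when(age >= 65, "senior").otherwise("adult")
    print(working_age.sql())  # renders the expression as Spark SQL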
Column( expression: Union[ColumnOrLiteral, sqlglot.expressions.Expression, NoneType])
17    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
18        if isinstance(expression, Column):
19            expression = expression.expression  # type: ignore
20        elif expression is None or not isinstance(expression, (str, exp.Expression)):
21            expression = self._lit(expression).expression  # type: ignore
22
23        expression = sqlglot.maybe_parse(expression, dialect="spark")
24        if expression is None:
25            raise ValueError(f"Could not parse {expression}")
26        self.expression: exp.Expression = expression
@classmethod
def ensure_col( cls, value: Union[ColumnOrLiteral, sqlglot.expressions.Expression, NoneType]):
112    @classmethod
113    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
114        return cls(value)
@classmethod
def ensure_cols( cls, args: List[Union[ColumnOrLiteral, sqlglot.expressions.Expression]]) -> List[sqlglot.dataframe.sql.Column]:
116    @classmethod
117    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
118        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
@classmethod
def invoke_anonymous_function( cls, column: Optional[ColumnOrLiteral], func_name: str, *args: Optional[ColumnOrLiteral]) -> sqlglot.dataframe.sql.Column:
127    @classmethod
128    def invoke_anonymous_function(
129        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
130    ) -> Column:
131        columns = [] if column is None else [cls.ensure_col(column)]
132        column_args = [cls.ensure_col(arg) for arg in args]
133        expressions = [x.expression for x in columns + column_args]
134        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
135        return Column(new_expression)
@classmethod
def invoke_expression_over_column( cls, column: Optional[ColumnOrLiteral], callable_expression: Callable, **kwargs) -> sqlglot.dataframe.sql.Column:
137    @classmethod
138    def invoke_expression_over_column(
139        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
140    ) -> Column:
141        ensured_column = None if column is None else cls.ensure_col(column)
142        ensure_expression_values = {
143            k: [Column.ensure_col(x).expression for x in v]
144            if is_iterable(v)
145            else Column.ensure_col(v).expression
146            for k, v in kwargs.items()
147            if v is not None
148        }
149        new_expression = (
150            callable_expression(**ensure_expression_values)
151            if ensured_column is None
152            else callable_expression(
153                this=ensured_column.column_expression, **ensure_expression_values
154            )
155        )
156        return Column(new_expression)
def binary_op( self, klass: Callable, other: ColumnOrLiteral, **kwargs) -> sqlglot.dataframe.sql.Column:
158    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
159        return Column(
160            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
161        )
def inverse_binary_op( self, klass: Callable, other: ColumnOrLiteral, **kwargs) -> sqlglot.dataframe.sql.Column:
163    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
164        return Column(
165            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
166        )
def unary_op(self, klass: Callable, **kwargs) -> sqlglot.dataframe.sql.Column:
168    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
169        return Column(klass(this=self.column_expression, **kwargs))
@classmethod
def ensure_literal(cls, value) -> sqlglot.dataframe.sql.Column:
187    @classmethod
188    def ensure_literal(cls, value) -> Column:
189        from sqlglot.dataframe.sql.functions import lit
190
191        if isinstance(value, cls):
192            value = value.expression
193        if not isinstance(value, exp.Literal):
194            return lit(value)
195        return Column(value)
def copy(self) -> sqlglot.dataframe.sql.Column:
197    def copy(self) -> Column:
198        return Column(self.expression.copy())
def set_table_name(self, table_name: str, copy=False) -> sqlglot.dataframe.sql.Column:
200    def set_table_name(self, table_name: str, copy=False) -> Column:
201        expression = self.expression.copy() if copy else self.expression
202        expression.set("table", exp.to_identifier(table_name))
203        return Column(expression)
def sql(self, **kwargs) -> str:
205    def sql(self, **kwargs) -> str:
206        return self.expression.sql(**{"dialect": "spark", **kwargs})
def alias(self, name: str) -> sqlglot.dataframe.sql.Column:
208    def alias(self, name: str) -> Column:
209        new_expression = exp.alias_(self.column_expression, name)
210        return Column(new_expression)
def asc(self) -> sqlglot.dataframe.sql.Column:
212    def asc(self) -> Column:
213        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
214        return Column(new_expression)
def desc(self) -> sqlglot.dataframe.sql.Column:
216    def desc(self) -> Column:
217        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
218        return Column(new_expression)
def asc_nulls_first(self) -> sqlglot.dataframe.sql.Column:
212    def asc(self) -> Column:
213        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
214        return Column(new_expression)
def asc_nulls_last(self) -> sqlglot.dataframe.sql.Column:
222    def asc_nulls_last(self) -> Column:
223        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
224        return Column(new_expression)
def desc_nulls_first(self) -> sqlglot.dataframe.sql.Column:
226    def desc_nulls_first(self) -> Column:
227        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
228        return Column(new_expression)
def desc_nulls_last(self) -> sqlglot.dataframe.sql.Column:
216    def desc(self) -> Column:
217        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
218        return Column(new_expression)
def when( self, condition: sqlglot.dataframe.sql.Column, value: Any) -> sqlglot.dataframe.sql.Column:
232    def when(self, condition: Column, value: t.Any) -> Column:
233        from sqlglot.dataframe.sql.functions import when
234
235        column_with_if = when(condition, value)
236        if not isinstance(self.expression, exp.Case):
237            return column_with_if
238        new_column = self.copy()
239        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
240        return new_column
def otherwise(self, value: Any) -> sqlglot.dataframe.sql.Column:
242    def otherwise(self, value: t.Any) -> Column:
243        from sqlglot.dataframe.sql.functions import lit
244
245        true_value = value if isinstance(value, Column) else lit(value)
246        new_column = self.copy()
247        new_column.expression.set("default", true_value.column_expression)
248        return new_column
def isNull(self) -> sqlglot.dataframe.sql.Column:
250    def isNull(self) -> Column:
251        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
252        return Column(new_expression)
def isNotNull(self) -> sqlglot.dataframe.sql.Column:
254    def isNotNull(self) -> Column:
255        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
256        return Column(new_expression)
def cast(self, dataType: Union[str, sqlglot.dataframe.sql.types.DataType]):
258    def cast(self, dataType: t.Union[str, DataType]):
259        """
260        Functionality Difference: PySpark's cast accepts an instance of the DataType class.
261        Sqlglot doesn't currently replicate this class, so it only accepts a string.
262        """
263        if isinstance(dataType, DataType):
264            dataType = dataType.simpleString()
265        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))


def startswith( self, value: Union[str, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.Column:
267    def startswith(self, value: t.Union[str, Column]) -> Column:
268        value = self._lit(value) if not isinstance(value, Column) else value
269        return self.invoke_anonymous_function(self, "STARTSWITH", value)
def endswith( self, value: Union[str, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.Column:
271    def endswith(self, value: t.Union[str, Column]) -> Column:
272        value = self._lit(value) if not isinstance(value, Column) else value
273        return self.invoke_anonymous_function(self, "ENDSWITH", value)
def rlike(self, regexp: str) -> sqlglot.dataframe.sql.Column:
275    def rlike(self, regexp: str) -> Column:
276        return self.invoke_expression_over_column(
277            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
278        )
def like(self, other: str):
280    def like(self, other: str):
281        return self.invoke_expression_over_column(
282            self, exp.Like, expression=self._lit(other).expression
283        )
def ilike(self, other: str):
285    def ilike(self, other: str):
286        return self.invoke_expression_over_column(
287            self, exp.ILike, expression=self._lit(other).expression
288        )
def substr( self, startPos: Union[int, sqlglot.dataframe.sql.Column], length: Union[int, sqlglot.dataframe.sql.Column]) -> sqlglot.dataframe.sql.Column:
290    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
291        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
292        length = self._lit(length) if not isinstance(length, Column) else length
293        return Column.invoke_expression_over_column(
294            self, exp.Substring, start=startPos.expression, length=length.expression
295        )
def isin( self, *cols: Union[ColumnOrLiteral, Iterable[ColumnOrLiteral]]):
297    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
298        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
299        expressions = [self._lit(x).expression for x in columns]
300        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
def between( self, lowerBound: ColumnOrLiteral, upperBound: ColumnOrLiteral) -> sqlglot.dataframe.sql.Column:
302    def between(
303        self,
304        lowerBound: t.Union[ColumnOrLiteral],
305        upperBound: t.Union[ColumnOrLiteral],
306    ) -> Column:
307        lower_bound_exp = (
308            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
309        )
310        upper_bound_exp = (
311            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
312        )
313        return Column(
314            exp.Between(
315                this=self.column_expression,
316                low=lower_bound_exp.expression,
317                high=upper_bound_exp.expression,
318            )
319        )
def over( self, window: WindowSpec) -> sqlglot.dataframe.sql.Column:
321    def over(self, window: WindowSpec) -> Column:
322        window_expression = window.expression.copy()
323        window_expression.set("this", self.column_expression)
324        return Column(window_expression)
class DataFrameNaFunctions:
784class DataFrameNaFunctions:
785    def __init__(self, df: DataFrame):
786        self.df = df
787
788    def drop(
789        self,
790        how: str = "any",
791        thresh: t.Optional[int] = None,
792        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
793    ) -> DataFrame:
794        return self.df.dropna(how=how, thresh=thresh, subset=subset)
795
796    def fill(
797        self,
798        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
799        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
800    ) -> DataFrame:
801        return self.df.fillna(value=value, subset=subset)
802
803    def replace(
804        self,
805        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
806        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
807        subset: t.Optional[t.Union[str, t.List[str]]] = None,
808    ) -> DataFrame:
809        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
DataFrameNaFunctions(df: sqlglot.dataframe.sql.DataFrame)
785    def __init__(self, df: DataFrame):
786        self.df = df
def drop( self, how: str = 'any', thresh: Optional[int] = None, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
788    def drop(
789        self,
790        how: str = "any",
791        thresh: t.Optional[int] = None,
792        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
793    ) -> DataFrame:
794        return self.df.dropna(how=how, thresh=thresh, subset=subset)
def fill( self, value: Union[int, bool, float, str, Dict[str, Any]], subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
796    def fill(
797        self,
798        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
799        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
800    ) -> DataFrame:
801        return self.df.fillna(value=value, subset=subset)
def replace( self, to_replace: Union[bool, int, float, str, List, Dict], value: Union[bool, int, float, str, List, NoneType] = None, subset: Union[str, List[str], NoneType] = None) -> sqlglot.dataframe.sql.DataFrame:
803    def replace(
804        self,
805        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
806        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
807        subset: t.Optional[t.Union[str, t.List[str]]] = None,
808    ) -> DataFrame:
809        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
class Window:
15class Window:
16    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
17    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
18    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
19    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)
20
21    unboundedPreceding: int = _JAVA_MIN_LONG
22
23    unboundedFollowing: int = _JAVA_MAX_LONG
24
25    currentRow: int = 0
26
27    @classmethod
28    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
29        return WindowSpec().partitionBy(*cols)
30
31    @classmethod
32    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
33        return WindowSpec().orderBy(*cols)
34
35    @classmethod
36    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
37        return WindowSpec().rowsBetween(start, end)
38
39    @classmethod
40    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
41        return WindowSpec().rangeBetween(start, end)
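
Example, a sketch of a running total per department, assuming F.sum from this module's functions and illustrative column names:

    from sqlglot.dataframe.sql import Window, functions as F

    w = (
        Window.partitionBy("dept_id")
        .orderBy("salary")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    running_total = F.sum("salary").over(w).alias("running_total")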
@classmethod
def partitionBy( cls, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
27    @classmethod
28    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
29        return WindowSpec().partitionBy(*cols)
@classmethod
def orderBy( cls, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
31    @classmethod
32    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
33        return WindowSpec().orderBy(*cols)
@classmethod
def rowsBetween(cls, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
35    @classmethod
36    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
37        return WindowSpec().rowsBetween(start, end)
@classmethod
def rangeBetween(cls, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
39    @classmethod
40    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
41        return WindowSpec().rangeBetween(start, end)
class WindowSpec:
 44class WindowSpec:
 45    def __init__(self, expression: exp.Expression = exp.Window()):
 46        self.expression = expression
 47
 48    def copy(self):
 49        return WindowSpec(self.expression.copy())
 50
 51    def sql(self, **kwargs) -> str:
 52        return self.expression.sql(dialect="spark", **kwargs)
 53
 54    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
 55        from sqlglot.dataframe.sql.column import Column
 56
 57        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
 58        expressions = [Column.ensure_col(x).expression for x in cols]
 59        window_spec = self.copy()
 60        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
 61        partition_by_expressions.extend(expressions)
 62        window_spec.expression.set("partition_by", partition_by_expressions)
 63        return window_spec
 64
 65    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
 66        from sqlglot.dataframe.sql.column import Column
 67
 68        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
 69        expressions = [Column.ensure_col(x).expression for x in cols]
 70        window_spec = self.copy()
 71        if window_spec.expression.args.get("order") is None:
 72            window_spec.expression.set("order", exp.Order(expressions=[]))
 73        order_by = window_spec.expression.args["order"].expressions
 74        order_by.extend(expressions)
 75        window_spec.expression.args["order"].set("expressions", order_by)
 76        return window_spec
 77
 78    def _calc_start_end(
 79        self, start: int, end: int
 80    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
 81        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
 82            "start_side": None,
 83            "end_side": None,
 84        }
 85        if start == Window.currentRow:
 86            kwargs["start"] = "CURRENT ROW"
 87        else:
 88            kwargs = {
 89                **kwargs,
 90                **{
 91                    "start_side": "PRECEDING",
 92                    "start": "UNBOUNDED"
 93                    if start <= Window.unboundedPreceding
 94                    else F.lit(start).expression,
 95                },
 96            }
 97        if end == Window.currentRow:
 98            kwargs["end"] = "CURRENT ROW"
 99        else:
100            kwargs = {
101                **kwargs,
102                **{
103                    "end_side": "FOLLOWING",
104                    "end": "UNBOUNDED"
105                    if end >= Window.unboundedFollowing
106                    else F.lit(end).expression,
107                },
108            }
109        return kwargs
110
111    def rowsBetween(self, start: int, end: int) -> WindowSpec:
112        window_spec = self.copy()
113        spec = self._calc_start_end(start, end)
114        spec["kind"] = "ROWS"
115        window_spec.expression.set(
116            "spec",
117            exp.WindowSpec(
118                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
119            ),
120        )
121        return window_spec
122
123    def rangeBetween(self, start: int, end: int) -> WindowSpec:
124        window_spec = self.copy()
125        spec = self._calc_start_end(start, end)
126        spec["kind"] = "RANGE"
127        window_spec.expression.set(
128            "spec",
129            exp.WindowSpec(
130                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
131            ),
132        )
133        return window_spec
WindowSpec(expression: sqlglot.expressions.Expression = exp.Window())
def copy(self):
def sql(self, **kwargs) -> str:
def partitionBy(self, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
def orderBy(self, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> sqlglot.dataframe.sql.WindowSpec:
def rowsBetween(self, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
def rangeBetween(self, start: int, end: int) -> sqlglot.dataframe.sql.WindowSpec:
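rowsBetween and rangeBetween rely on _calc_start_end, which translates PySpark's integer boundaries into SQL frame tokens: Window.currentRow becomes CURRENT ROW, values at or beyond the unboundedPreceding/unboundedFollowing sentinels become UNBOUNDED PRECEDING/FOLLOWING, and any other value becomes a literal offset. A hedged sketch using the unbounded sentinels (event_ts is an illustrative column):

from sqlglot.dataframe.sql.window import Window

rows_spec = Window.orderBy("event_ts").rowsBetween(Window.unboundedPreceding, Window.currentRow)
print(rows_spec.sql())
# Something like: OVER (ORDER BY event_ts ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)

range_spec = Window.orderBy("event_ts").rangeBetween(Window.unboundedPreceding, Window.currentRow)
print(range_spec.sql())
# Same bounds, but a RANGE frame: peers are grouped by value rather than by row position.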
class DataFrameReader:
15class DataFrameReader:
16    def __init__(self, spark: SparkSession):
17        self.spark = spark
18
19    def table(self, tableName: str) -> DataFrame:
20        from sqlglot.dataframe.sql.dataframe import DataFrame
21
22        sqlglot.schema.add_table(tableName)
23        return DataFrame(
24            self.spark,
25            exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
26        )
DataFrameReader(spark: sqlglot.dataframe.sql.SparkSession)
def table(self, tableName: str) -> sqlglot.dataframe.sql.DataFrame:
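Because table() selects the table's known columns, the table generally needs to be registered with sqlglot's schema first. A sketch, assuming an illustrative employees table:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

# Register columns so sqlglot.schema.column_names() has something to expand.
sqlglot.schema.add_table("employees", {"id": "int", "dept": "string", "salary": "int"})

spark = SparkSession()
df = spark.read.table("employees")
print(df.sql()[0])  # DataFrame.sql() returns a list of statements
# Something like: SELECT id, dept, salary FROM employees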
class DataFrameWriter:
29class DataFrameWriter:
30    def __init__(
31        self,
32        df: DataFrame,
33        spark: t.Optional[SparkSession] = None,
34        mode: t.Optional[str] = None,
35        by_name: bool = False,
36    ):
37        self._df = df
38        self._spark = spark or df.spark
39        self._mode = mode
40        self._by_name = by_name
41
42    def copy(self, **kwargs) -> DataFrameWriter:
43        return DataFrameWriter(
44            **{
45                k[1:] if k.startswith("_") else k: v
46                for k, v in object_to_dict(self, **kwargs).items()
47            }
48        )
49
50    def sql(self, **kwargs) -> t.List[str]:
51        return self._df.sql(**kwargs)
52
53    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
54        return self.copy(_mode=saveMode)
55
56    @property
57    def byName(self):
58        return self.copy(by_name=True)
59
60    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
61        output_expression_container = exp.Insert(
62            **{
63                "this": exp.to_table(tableName),
64                "overwrite": overwrite,
65            }
66        )
67        df = self._df.copy(output_expression_container=output_expression_container)
68        if self._by_name:
69            columns = sqlglot.schema.column_names(tableName, only_visible=True)
70            df = df._convert_leaf_to_cte().select(*columns)
71
72        return self.copy(_df=df)
73
74    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
75        if format is not None:
76            raise NotImplementedError("Providing Format in the save as table is not supported")
77        exists, replace, mode = None, None, mode or str(self._mode)
78        if mode == "append":
79            return self.insertInto(name)
80        if mode == "ignore":
81            exists = True
82        if mode == "overwrite":
83            replace = True
84        output_expression_container = exp.Create(
85            this=exp.to_table(name),
86            kind="TABLE",
87            exists=exists,
88            replace=replace,
89        )
90        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
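saveAsTable maps Spark's save modes onto the generated statement: "append" is routed to insertInto, "ignore" sets the IF NOT EXISTS flag, and "overwrite" sets OR REPLACE. A hedged sketch, assuming df is the DataFrame built in the reader example above and that DataFrame.write returns a DataFrameWriter as in PySpark:

writer = df.write.mode("overwrite").saveAsTable("employees_backup")
print(writer.sql()[0])
# Something like: CREATE OR REPLACE TABLE employees_backup AS SELECT ...

print(df.write.mode("ignore").saveAsTable("employees_backup").sql()[0])
# Something like: CREATE TABLE IF NOT EXISTS employees_backup AS SELECT ...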
DataFrameWriter(df: sqlglot.dataframe.sql.DataFrame, spark: Optional[sqlglot.dataframe.sql.SparkSession] = None, mode: Optional[str] = None, by_name: bool = False)
def copy(self, **kwargs) -> sqlglot.dataframe.sql.DataFrameWriter:
def sql(self, **kwargs) -> List[str]:
def mode(self, saveMode: Optional[str]) -> sqlglot.dataframe.sql.DataFrameWriter:
def insertInto(self, tableName: str, overwrite: Optional[bool] = None) -> sqlglot.dataframe.sql.DataFrameWriter:
def saveAsTable(self, name: str, format: Optional[str] = None, mode: Optional[str] = None):
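When byName is set, insertInto re-selects the target's visible columns so the projection is matched by name rather than by position. A sketch, reusing the illustrative df and employees table from above:

print(df.write.byName.insertInto("employees").sql()[0])
# Something like: INSERT INTO employees SELECT id, dept, salary FROM ...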